Merge branch 'ggerganov:master' into master
commit 3c0b830808
37 changed files with 3523 additions and 1781 deletions
.github/workflows/build.yml (vendored, 2 changes)

@@ -840,7 +840,7 @@ jobs:
         id: pack_artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
         run: |
-          7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip .\build\bin\*
+          7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/*

       - name: Upload artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
CMakeLists.txt

@@ -1170,6 +1170,7 @@ add_library(llama
             llama.h
             unicode.h
             unicode.cpp
+            unicode-data.cpp
             )

 target_include_directories(llama PUBLIC .)
Makefile (5 changes)

@@ -678,7 +678,10 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h
 unicode.o: unicode.cpp unicode.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

-OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o
+unicode-data.o: unicode-data.cpp unicode-data.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o

 llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
Package.swift

@@ -32,6 +32,7 @@ let package = Package(
         "ggml.c",
         "llama.cpp",
         "unicode.cpp",
+        "unicode-data.cpp",
         "ggml-alloc.c",
         "ggml-backend.c",
         "ggml-quants.c",
README.md (10 changes)

@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

 ### Recent API changes

+- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
 - [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
 - [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
 - [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796

@@ -631,6 +632,15 @@ Building the program with BLAS support may lead to some performance improvements

 - #### Vulkan

+  > [!WARNING]
+  >
+  > Vulkan support has been broken in https://github.com/ggerganov/llama.cpp/pull/6122
+  > due to relying on `GGML_OP_GET_ROWS` which is not yet properly supported by the Vulkan backend,
+  > but should be fixed relatively soon (possibly in https://github.com/ggerganov/llama.cpp/pull/6155
+  > (ref: https://github.com/ggerganov/llama.cpp/pull/6122#issuecomment-2015327635)).
+  >
+  > Meanwhile, if you want to use the Vulkan backend, you should use the commit right before the breaking change, https://github.com/ggerganov/llama.cpp/commit/55c1b2a3bbd470e9e2a3a0618b92cf64a885f806
+
   **With docker**:

   You don't need to install Vulkan SDK. It will be installed inside the container.
build.zig (15 changes)

@@ -116,6 +116,7 @@ pub fn build(b: *std.build.Builder) !void {
     const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
     const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
     const unicode = make.obj("unicode", "unicode.cpp");
+    const unicode_data = make.obj("unicode-data", "unicode-data.cpp");
     const llama = make.obj("llama", "llama.cpp");
     const buildinfo = make.obj("common", "common/build-info.cpp");
     const common = make.obj("common", "common/common.cpp");

@@ -127,14 +128,14 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");

-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train });

-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
     if (server.target.isWindows()) {
         server.linkSystemLibrary("ws2_32");
     }
convert-hf-to-gguf.py

@@ -331,7 +331,7 @@ class Model(ABC):
         tokenizer = SentencePieceProcessor(str(tokenizer_path))
         vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

-        for token_id in range(vocab_size):
+        for token_id in range(tokenizer.vocab_size()):
             piece = tokenizer.id_to_piece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.get_score(token_id)

@@ -356,10 +356,14 @@ class Model(ABC):
                 added_tokens_json = json.load(f)

                 for key in added_tokens_json:
-                    tokens.append(key.encode("utf-8"))
+                    key = key.encode("utf-8")
+                    if key not in tokens:
+                        tokens.append(key)
                     scores.append(-1000.0)
                     toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

+        assert len(tokens) == vocab_size
+
         self.gguf_writer.add_tokenizer_model("llama")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_scores(scores)
examples/embedding/embedding.cpp

@@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
     }

     params.embedding = true;
+    // For non-causal models, batch size must be equal to ubatch size
+    params.n_ubatch = params.n_batch;

     print_build_info();

@@ -114,7 +116,9 @@ int main(int argc, char ** argv) {
     for (const auto & prompt : prompts) {
         auto inp = ::llama_tokenize(ctx, prompt, true, false);
         if (inp.size() > n_batch) {
-            inp.resize(n_batch);
+            fprintf(stderr, "%s: error: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
+                    __func__, (long long int) inp.size(), (long long int) n_batch);
+            return 1;
         }
         inputs.push_back(inp);
     }
examples/imatrix/imatrix.cpp

@@ -424,6 +424,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
             tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
         }

+        // TODO: use batch.logits to save computations instead of relying on logits_all == true
         if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
examples/parallel/parallel.cpp

@@ -132,7 +132,6 @@ int main(int argc, char ** argv) {
     llama_context * ctx = NULL;

     // load the target model
-    params.logits_all = true;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);

     // load the prompts from an external file if there are any
examples/perplexity/perplexity.cpp

@@ -380,6 +380,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const int batch_size = std::min(end - batch_start, n_batch);

         //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+        // TODO: use llama_batch.logits instead of relying on logits_all == true
         if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
             //fprintf(stderr, "%s : failed to eval\n", __func__);
             return {tokens, -1, logit_history, prob_history};

@@ -552,6 +553,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);

+            int n_outputs = 0;
+
             batch.n_tokens = 0;
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 int seq_start = batch_start + seq*n_ctx;

@@ -571,6 +574,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                     batch.n_seq_id[idx]    = 1;
                     batch.seq_id  [idx][0] = seq;
                     batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+
+                    n_outputs += batch.logits[idx] != 0;
                 }
                 batch.n_tokens += batch_size;

@@ -583,9 +588,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 return {tokens, -1, logit_history, prob_history};
             }

-            if (num_batches > 1) {
+            if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
             }
         }

@@ -604,14 +609,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         }

         for (int seq = 0; seq < n_seq_batch; seq++) {
-            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+
             llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
             if (!params.logits_file.empty()) {
-                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                process_logits(logits_stream, n_vocab, all_logits,
                     tokens_data, n_ctx - 1 - first,
                     workers, log_probs, nll, nll2);
             } else {
-                process_logits(n_vocab, all_logits + first*n_vocab,
+                process_logits(n_vocab, all_logits,
                     tokens_data, n_ctx - 1 - first,
                     workers, nll, nll2,
                     logit_history.data() + start + seq*n_ctx + first,

@@ -652,6 +658,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 }

 static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+    int prev_outputs = 0;
     for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
         const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

@@ -672,7 +679,14 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             return false;
         }

-        memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
+        int n_outputs = 0;
+        for (int i = 0; i < n_tokens; ++i) {
+            n_outputs += batch_view.logits[i] != 0;
+        }
+
+        memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
+        prev_outputs += n_outputs;
     }

     return true;
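Note on the pattern above (editorial, not part of the diff): since the logits compaction referenced in the README hunk (PR 6122), llama_get_logits() returns rows only for tokens whose llama_batch::logits flag was set, packed back to back. A token's row is found by counting flagged tokens before it, which is exactly what the new n_outputs/prev_outputs/i_logits bookkeeping does. A minimal standalone sketch of that indexing, with illustrative names:

#include <cstdint>
#include <vector>

// Sketch only: compute the row of a token's logits in the packed buffer
// returned by llama_get_logits(). 'logits_flags' mirrors llama_batch::logits.
static int packed_logits_row(const std::vector<int8_t> & logits_flags, int token_idx) {
    int row = 0;
    for (int i = 0; i < token_idx; ++i) {
        row += logits_flags[i] != 0; // only flagged tokens occupy an output row
    }
    return row; // index into an [n_outputs x n_vocab] buffer, not [n_tokens x n_vocab]
}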
@@ -779,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         size_t ending_logprob_count[4];
         double ending_logprob[4];

-        size_t i_batch;         // starting index in the llama_batch
+        size_t i_logits;        // starting index of logits in the llama_batch
         size_t common_prefix;   // max number of initial tokens that are the same in all sentences
         size_t required_tokens; // needed number of tokens to evaluate all 4 endings
         std::vector<llama_token> seq_tokens[4];

@@ -844,9 +858,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+    llama_batch batch = llama_batch_init(n_ctx, 0, 4);

     std::vector<float> tok_logits(n_vocab);
+    // TODO: this could be made smaller; it's currently the worst-case size
     std::vector<float> batch_logits(n_vocab*n_ctx);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;

@@ -857,16 +872,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         int n_cur = 0;

         size_t i1 = i0;
-        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

         llama_batch_clear(batch);

         // batch as much tasks as possible into the available context
-        // each task has 4 unique seuqnce ids - one for each ending
+        // each task has 4 unique sequence ids - one for each ending
         // the common prefix is shared among the 4 sequences to save tokens
         // we extract logits only from the last common token and from all ending tokens of each sequence
         while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
             auto & hs_cur = hs_data[i1];
+            int n_logits = 0;

             const int s0 = 4*(i1 - i0);
             if (s0 + 4 > max_seq) {

@@ -877,15 +893,20 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
                 llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
             }
             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            n_logits += 1;

             for (int s = 0; s < 4; ++s) {
-                for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
-                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+                const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
+                // TODO: don't evaluate the last token of each sequence
+                for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
+                    const bool needs_logits = i < seq_tokens_size - 1;
+                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    n_logits += needs_logits;
                 }
             }

-            hs_cur.i_batch = i_batch;
-            i_batch += hs_cur.required_tokens;
+            hs_cur.i_logits = i_logits;
+            i_logits += n_logits;

             n_cur += hs_data[i1].required_tokens;
             if (++i1 == hs_task_count) {

@@ -911,12 +932,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         eval_pairs.clear();
         for (size_t i = i0; i < i1; ++i) {
             auto & hs_cur = hs_data[i];
-            size_t li = hs_cur.common_prefix;
+            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
             for (int s = 0; s < 4; ++s) {
                 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
-                    eval_pairs.emplace_back(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]);
+                    eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
                 }
-                ++li;
             }
         }
         // Then we do the actual calculation

@@ -928,7 +948,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         for (size_t i = i0; i < i1; ++i) {
             auto & hs_cur = hs_data[i];

-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
+            // get the logits of the last token of the common prefix
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -978,7 +999,7 @@ struct winogrande_entry {
     std::array<std::string, 2> choices;
     int answer;

-    size_t i_batch;
+    size_t i_logits;
     size_t common_prefix;
     size_t required_tokens;
     size_t n_base1; // number of tokens for context + choice 1

@@ -1104,6 +1125,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             task.common_prefix++;
         }

+        // TODO: the last token of each of the sequences don't need to be evaluated
         task.required_tokens = task.common_prefix +
             task.seq_tokens[0].size() - task.common_prefix +
             task.seq_tokens[1].size() - task.common_prefix;

@@ -1121,9 +1143,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+    llama_batch batch = llama_batch_init(n_ctx, 0, 2);

     std::vector<float> tok_logits(n_vocab);
+    // TODO: this could be made smaller; it's currently the worst-case size
     std::vector<float> batch_logits(n_vocab*n_ctx);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;

@@ -1137,11 +1160,12 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         int n_cur = 0;

         size_t i1 = i0;
-        size_t i_batch = 0;
+        size_t i_logits = 0;

         llama_batch_clear(batch);

         while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+            int n_logits = 0;
             const int s0 = 2*(i1 - i0);
             if (s0 + 2 > max_seq) {
                 break;

@@ -1151,15 +1175,18 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
                 llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
             }
             batch.logits[batch.n_tokens - 1] = true;
+            n_logits += 1;

             for (int s = 0; s < 2; ++s) {
+                // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
                     llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    n_logits += 1;
                 }
             }

-            data[i1].i_batch = i_batch;
-            i_batch += data[i1].required_tokens;
+            data[i1].i_logits = i_logits;
+            i_logits += n_logits;

             n_cur += data[i1].required_tokens;
             if (++i1 == data.size()) {

@@ -1190,15 +1217,16 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

             const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
             const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
-            size_t li = n_base1 - 1;
+            size_t li = n_base1 - task.common_prefix;
             for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
-                eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[0][j+1]);
+                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
             }
             const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
             const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
-            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+            // FIXME: this uses the wrong first logits when not skipping the choice word
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
             for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
-                eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[1][j+1]);
+                eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
             }
         }
         compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

@@ -1287,7 +1315,7 @@ struct multiple_choice_task {
     }

     // For evaluation
-    size_t i_batch;         // starting index in the llama_batch
+    size_t i_logits;        // starting index of logits in the llama_batch
     size_t common_prefix;   // max number of initial tokens that are the same in all sentences
     size_t required_tokens; // needed number of tokens to evaluate all answers
     std::vector<std::vector<llama_token>> seq_tokens;

@@ -1366,7 +1394,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         std::vector<uint32_t> task_pos(n_task);
         strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
         if (strstream.fail()) {
-            printf("%s: failed to raad task positions from prompt\n", __func__);
+            printf("%s: failed to read task positions from prompt\n", __func__);
             return;
         }

@@ -1447,7 +1475,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             return;
         }
     } else {
-        int n_dot = n_task/100;
+        int n_dot = std::max((int) n_task/100, 1);
         int i_task = 0;
         for (auto& task : tasks) {
             ++i_task;

@@ -1491,17 +1519,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         int n_cur = 0;

         size_t i1 = i0;
-        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

         llama_batch_clear(batch);

         // batch as much tasks as possible into the available context
-        // each task has 4 unique seuqnce ids - one for each ending
+        // each task has 4 unique sequence ids - one for each ending
         // the common prefix is shared among the 4 sequences to save tokens
         // we extract logits only from the last common token and from all ending tokens of each sequence
         int s0 = 0;
         while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
             auto& cur_task = tasks[i1];
+            int n_logits = 0;

             int num_answers = cur_task.seq_tokens.size();
             if (s0 + num_answers > max_seq) {

@@ -1518,17 +1547,22 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            n_logits += 1;

             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
-                for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) {
-                    llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true);
+                const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
+                // TODO: don't evaluate the last token of each sequence
+                for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
+                    const bool needs_logits = i < seq_tokens_size - 1;
+                    llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    n_logits += needs_logits;
                 }
             }

             s0 += num_answers;

-            cur_task.i_batch = i_batch;
-            i_batch += cur_task.required_tokens;
+            cur_task.i_logits = i_logits;
+            i_logits += n_logits;

             n_cur += cur_task.required_tokens;
             if (++i1 == tasks.size()) {

@@ -1554,12 +1588,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         eval_pairs.clear();
         for (size_t i = i0; i < i1; ++i) {
             auto& cur_task = tasks[i];
-            size_t li = cur_task.common_prefix;
+            size_t li = 1; // skip the last logit of the common prefix (computed separately below)
             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
                 for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
-                    eval_pairs.emplace_back(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]);
+                    eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
                 }
-                ++li;
             }
         }
         // Then we do the actual calculation

@@ -1578,7 +1611,8 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             //}
             //printf("\n    common_prefix: %zu\n", cur_task.common_prefix);

-            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
+            // get the logits of the last token of the common prefix
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -1730,6 +1764,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
         }

+        // TODO: use llama_batch.logits instead of relying on logits_all == true
         if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
examples/quantize/quantize.cpp

@@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_S",   LLAMA_FTYPE_MOSTLY_IQ2_S,   " 2.5 bpw quantization", },
     { "IQ2_M",   LLAMA_FTYPE_MOSTLY_IQ2_M,   " 2.7 bpw quantization", },
     { "IQ1_S",   LLAMA_FTYPE_MOSTLY_IQ1_S,   " 1.56 bpw quantization", },
+    { "IQ1_M",   LLAMA_FTYPE_MOSTLY_IQ1_M,   " 1.75 bpw quantization", },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },

@@ -87,13 +88,17 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
     printf("  --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n");
     printf("  --include-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
+    printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
+    printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
+    printf("  --override-kv KEY=TYPE:VALUE\n");
+    printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
     printf("Note: --include-weights and --exclude-weights cannot be used together\n");
     printf("\nAllowed quantization types:\n");
     for (auto & it : QUANT_OPTIONS) {

@@ -201,6 +206,43 @@ static ggml_type parse_ggml_type(const char * arg) {
     return result;
 }

+static bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+    const char* sep = strchr(data, '=');
+    if (sep == nullptr || sep - data >= 128) {
+        fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
+        return false;
+    }
+    llama_model_kv_override kvo;
+    std::strncpy(kvo.key, data, sep - data);
+    kvo.key[sep - data] = 0;
+    sep++;
+    if (strncmp(sep, "int:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+        kvo.int_value = std::atol(sep);
+    } else if (strncmp(sep, "float:", 6) == 0) {
+        sep += 6;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+        kvo.float_value = std::atof(sep);
+    } else if (strncmp(sep, "bool:", 5) == 0) {
+        sep += 5;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+        if (std::strcmp(sep, "true") == 0) {
+            kvo.bool_value = true;
+        } else if (std::strcmp(sep, "false") == 0) {
+            kvo.bool_value = false;
+        } else {
+            fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
+            return false;
+        }
+    } else {
+        fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
+        return false;
+    }
+    overrides.emplace_back(std::move(kvo));
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
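Usage note (editorial, not part of the diff): parse_kv_override() above accepts KEY=TYPE:VALUE with TYPE one of int, float or bool, and main() later terminates the override list with an empty-key sentinel before assigning params.kv_overrides. A hedged sketch of an invocation; the metadata key shown is only an example, not taken from this diff:

// Illustrative only: an example --override-kv invocation.
//   ./quantize --override-kv tokenizer.ggml.add_bos_token=bool:false \
//       model-f32.gguf model-q4_k_m.gguf Q4_K_M
//
// "int:"   -> LLAMA_KV_OVERRIDE_TYPE_INT   (value parsed with atol)
// "float:" -> LLAMA_KV_OVERRIDE_TYPE_FLOAT (value parsed with atof)
// "bool:"  -> LLAMA_KV_OVERRIDE_TYPE_BOOL  ("true" or "false" only)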
@@ -211,6 +253,7 @@ int main(int argc, char ** argv) {
     int arg_idx = 1;
     std::string imatrix_file;
     std::vector<std::string> included_weights, excluded_weights;
+    std::vector<llama_model_kv_override> kv_overrides;

     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {

@@ -227,6 +270,10 @@ int main(int argc, char ** argv) {
             } else {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
+            if (arg_idx == argc-1 || !parse_kv_override(argv[++arg_idx], kv_overrides)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
             params.allow_requantize = true;
         } else if (strcmp(argv[arg_idx], "--pure") == 0) {

@@ -267,6 +314,11 @@ int main(int argc, char ** argv) {
     if (!imatrix_data.empty()) {
         params.imatrix = &imatrix_data;
     }
+    if (!kv_overrides.empty()) {
+        kv_overrides.emplace_back();
+        kv_overrides.back().key[0] = 0;
+        params.kv_overrides = &kv_overrides;
+    }

     llama_backend_init();

@@ -288,8 +340,7 @@ int main(int argc, char ** argv) {
         if (ftype_str == "COPY") {
             params.only_copy = true;
         }
-    }
-    else {
+    } else {
         fname_out = argv[arg_idx];
         arg_idx++;

@@ -320,10 +371,12 @@ int main(int argc, char ** argv) {

     if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
          params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S ||
-         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
-        fprintf(stderr, "\n===============================================================================================\n");
-        fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
-        fprintf(stderr, "===============================================================================================\n\n\n");
+         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S ||
+         params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
+         params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) {
+        fprintf(stderr, "\n==========================================================================================================\n");
+        fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
+        fprintf(stderr, "==========================================================================================================\n\n\n");
         return 1;
     }
examples/server/server.cpp

@@ -99,6 +99,7 @@ struct slot_params {

     uint32_t seed      = -1; // RNG seed
     int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
+    int32_t  n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t  n_predict = -1; // new tokens to predict

     std::vector<std::string> antiprompt;

@@ -746,7 +747,8 @@ struct server_context {
         {
             const int32_t n_batch = llama_n_batch(ctx);

-            batch = llama_batch_init(n_batch, 0, params.n_parallel);
+            // only a single seq_id per token is needed
+            batch = llama_batch_init(n_batch, 0, 1);
         }

         metrics.init();

@@ -846,6 +848,7 @@ struct server_context {
         slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
         slot.sparams.penalize_nl  = json_value(data, "penalize_nl",  default_sparams.penalize_nl);
         slot.params.n_keep        = json_value(data, "n_keep",       slot.params.n_keep);
+        slot.params.n_discard     = json_value(data, "n_discard",    default_params.n_discard);
         slot.params.seed          = json_value(data, "seed",         default_params.seed);
         slot.sparams.n_probs      = json_value(data, "n_probs",      default_sparams.n_probs);
         slot.sparams.min_keep     = json_value(data, "min_keep",     default_sparams.min_keep);

@@ -1253,6 +1256,7 @@ struct server_context {
             {"stop",       slot.params.antiprompt},
             {"n_predict",  slot.params.n_predict}, // TODO: fix duplicate key n_predict
             {"n_keep",     slot.params.n_keep},
+            {"n_discard",  slot.params.n_discard},
             {"ignore_eos", ignore_eos},
             {"stream",     slot.params.stream},
             {"logit_bias", slot.sparams.logit_bias},

@@ -1696,7 +1700,7 @@ struct server_context {
                 // Shift context
                 const int n_keep    = slot.params.n_keep + add_bos_token;
                 const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
-                const int n_discard = n_left / 2;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);

                 LOG_INFO("slot context shift", {
                     {"id_slot", slot.id},
examples/speculative/speculative.cpp

@@ -65,7 +65,6 @@ int main(int argc, char ** argv) {
     llama_context * ctx_dft = NULL;

     // load the target model
-    params.logits_all = true;
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

     // load the draft model
ggml-common.h

@@ -377,6 +377,20 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");

+// 1.8125 bpw
+typedef struct {
+    uint8_t  qs[QK_K/8];      // grid index, low 8 bits
+    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
+    uint8_t  scales[QK_K/32]; // 4-bit block scales
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+    ggml_half f16;
+    uint16_t  u16;
+} iq1m_scale_t;
+
 // Non-linear quants
 #define QK4_NL 32
 typedef struct {
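Layout note (editorial, inferred from the dequantize_block_iq1_m kernel later in this diff): block_iq1_m has no dedicated scale field; the four nibbles of the shared fp16 super-block scale live in the high 4 bits of the four 16-bit words of scales, and iq1m_scale_t reassembles them. A minimal host-side sketch with a hypothetical helper name:

#include <cstdint>
#include <cstring>

// Sketch only: rebuild the 16-bit scale pattern used by dequantize_block_iq1_m.
static uint16_t iq1m_scale_bits(const uint8_t scales[8]) {
    uint16_t sc[4];
    std::memcpy(sc, scales, sizeof(sc)); // four 16-bit words (little-endian assumed)
    // the high nibble of each word carries 4 bits of the fp16 scale
    return (uint16_t)((sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) |
                      ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000));
}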
@@ -1050,6 +1064,7 @@ GGML_TABLE_END()

 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
 #if defined(GGML_COMMON_IMPL_C)
 GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
     0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
ggml-cuda.cu

@@ -615,6 +615,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:

@@ -643,6 +644,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ3_S:

@@ -2503,7 +2505,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];

-        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
             continue;
         }

@@ -2560,7 +2562,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             ggml_type a_type = a->type;
             if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
                 a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
-                a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
+                a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
                 if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                     return false;
                 }
@@ -373,7 +373,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif

 }
@@ -395,7 +395,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif

 }
@@ -416,7 +416,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif

 }
@@ -444,7 +444,7 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif

 }
@@ -470,7 +470,7 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif

 }
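
Note: across these CUDA hunks, `assert(false)` in branches that should never run as device code (here, the `QK_K != 256` fallbacks) is replaced with ggml's `NO_DEVICE_CODE` macro; the `vec_dot_*` hunks further down also drop the now-unreachable `return 0.f;` lines at the same time. As a rough illustration only, a stand-in macro with similar behavior could look like the sketch below; this is an assumption, not ggml's actual definition:

    #include <stdio.h>
    #include <stdlib.h>

    /* Hypothetical stand-in for a NO_DEVICE_CODE-style macro: in device
     * code a trap instruction would be used; on the host we abort with a
     * message. The real macro in ggml's CUDA headers may differ. */
    #ifdef __CUDA_ARCH__
    #define NO_DEVICE_CODE __trap()
    #else
    #define NO_DEVICE_CODE do { fprintf(stderr, "unsupported configuration\n"); abort(); } while (0)
    #endif
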
@@ -496,11 +496,42 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
         y[j] = d * (q[j] + delta);
     }
 #else
-    assert(false);
+    NO_DEVICE_CODE;
 #endif

 }

+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq1_m * x = (const block_iq1_m *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+#else
+    NO_DEVICE_CODE;
+#endif
+
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {

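
Note: `block_iq1_m` has no dedicated fp16 scale field; the super-block scale is scattered across the top 4 bits of the four 16-bit words of `scales`, and `scale.u16 = (sc[0] >> 12) | ...` reassembles it through a union. A standalone C sketch of the same unpacking; the `_sketch` type names stand in for ggml's fp16 storage type and are assumptions:

    #include <stdint.h>

    typedef uint16_t ggml_half_sketch; /* stand-in for ggml's fp16 storage type */

    /* Mirrors the iq1m_scale_t union used in the diff: the reassembled 16
     * bits are reinterpreted as an fp16 value (converted with
     * GGML_FP16_TO_FP32 on the CPU path). */
    typedef union {
        ggml_half_sketch f16;
        uint16_t         u16;
    } iq1m_scale_sketch_t;

    /* Gather the super-block fp16 scale from the top 4 bits of each of the
     * four packed 16-bit scale words; the lower 12 bits of each word carry
     * the per-group 3-bit sub-scales decoded elsewhere. */
    static uint16_t iq1m_unpack_scale(const uint16_t sc[4]) {
        return (uint16_t)((sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) |
                          ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000));
    }
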
@@ -658,6 +689,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }

+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
@@ -724,6 +761,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_IQ4_XS:
@@ -769,6 +808,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_IQ4_XS:
@@ -282,6 +282,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }

+static void mul_mat_vec_iq1_m_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
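
Note: the new wrapper instantiates the generic mat-vec driver with `QK_K` weights per block, reuses `QI1_S` (the quants-per-int count of IQ1_S, which shares the block granularity), and plugs in the `vec_dot_iq1_m_q8_1` kernel defined below. A host-side C sketch of the same dispatch pattern, with hypothetical names (the real `mul_mat_vec_q_cuda` is a CUDA template in ggml, not this function):

    #include <stddef.h>

    typedef float (*vec_dot_fn)(const void * blk, const void * q8, int iqs);

    /* Sketch: a block-quantized mat-vec driver parameterized by the
     * per-block dot product, the pattern mul_mat_vec_iq1_m_q8_1_cuda
     * instantiates on the GPU. */
    static void mat_vec_blocks(const void * vx, const void * vy, float * dst,
                               int ncols, int nrows, size_t block_bytes,
                               int qk, vec_dot_fn vec_dot) {
        const int blocks_per_row = ncols / qk; /* qk = QK_K = 256 for IQ1_M */
        for (int row = 0; row < nrows; ++row) {
            const char * row_blocks = (const char *)vx + (size_t)row*blocks_per_row*block_bytes;
            float sum = 0.f;
            for (int ib = 0; ib < blocks_per_row; ++ib) {
                sum += vec_dot(row_blocks + (size_t)ib*block_bytes, vy, ib);
            }
            dst[row] = sum;
        }
    }
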
@@ -373,6 +381,9 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
@@ -961,8 +961,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     return d * (sumi1 + sumi2);
 #endif
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 }

@@ -1001,13 +1000,11 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
 #else
     GGML_UNUSED(ksigns64);
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 #else
     GGML_UNUSED(ksigns64);
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 }

@@ -1049,13 +1046,11 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
     return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
 #else
     GGML_UNUSED(ksigns64);
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 #else
     GGML_UNUSED(ksigns64);
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 }

@@ -1085,12 +1080,10 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
     return d * sumi;
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 }

@@ -1119,12 +1112,10 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
     const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
     return d * sumi;
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 }

@@ -1159,8 +1150,50 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const float m = d1q * __high2float(bq8_1[ib32].ds);
     return d * sumi + m * delta;
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 }

+static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+    const int ib32 = iqs;
+    int   sumi[2] = {0, 0};
+    float sumf[2] = {0.f, 0.f};
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    for (int l = 0; l < 4; ++l) {
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
+        int grid0 = grid[0] & 0x0f0f0f0f;
+        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+        sumi[l/2] = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi[l/2]));
+        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+        const int sumy = __dp4a(q8[2*l+1], 0x01010101, __dp4a(q8[2*l+0], 0x01010101, 0));
+        sumf[l/2] += delta*sumy;
+    }
+#else
+    const int8_t * q8 = bq8_1[ib32].qs;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        int sumy = 0;
+        for (int j = 0; j < 4; ++j) {
+            sumi[l/2] += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4);
+            sumy += q8[j] + q8[j+4];
+        }
+        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+        sumf[l/2] += delta*sumy;
+        q8 += 8;
+    }
+#endif
+    iq1m_scale_t scale;
+    const uint16_t * sc = (const uint16_t *)bq1->scales;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = (float)scale.f16 * __low2float (bq8_1[ib32].ds);
+    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
+#else
+    NO_DEVICE_CODE;
+#endif
+}
+
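
Note: in `vec_dot_iq1_m_q8_1` (and the dequantization kernels above) every 8 weights are addressed by an 11-bit index into the shared iq1s grid: 8 bits from `qs` plus 3 bits from a nibble of `qh`, whose remaining bit selects the sign of the per-group shift around -1. A standalone C sketch of the decode; the helper names are hypothetical and the `IQ1M_DELTA` value here is an assumed placeholder for ggml's constant:

    #include <stdint.h>

    #define IQ1M_DELTA_SKETCH 0.125f /* assumed placeholder value */

    /* l = 0..3 indexes the four 8-weight groups of a 32-weight block;
     * qs/qh point at the group's quant bytes inside the super-block. */
    static uint16_t iq1m_grid_index(const uint8_t * qs, const uint8_t * qh, int l) {
        const uint8_t h = (qh[l/2] >> 4*(l%2)) & 0xf; /* one qh nibble per group */
        return (uint16_t)(qs[l] | ((h & 7) << 8));    /* 11-bit grid index */
    }

    /* Bit 3 of the same nibble picks which side of -1 the group shifts to. */
    static float iq1m_group_delta(const uint8_t * qh, int l) {
        const uint8_t h = (qh[l/2] >> 4*(l%2)) & 0xf;
        return (h & 0x08) ? -1.f - IQ1M_DELTA_SKETCH : -1.f + IQ1M_DELTA_SKETCH;
    }

The `__dp4a(q8[...], 0x01010101, ...)` calls in the fast path compute the plain sum of the q8 values, which is all the constant per-group delta needs.
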
@@ -1223,27 +1256,6 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

-    //// iqs is 0...7
-    //const int ib64 = iqs/2;
-    //const int il = iqs%2;
-    //const int32_t * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il;
-    //const int32_t * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il;
-    //const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il;
-    //const uint32_t * q4_2 = q4_1 + 4;
-    //const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4);
-    //const int8_t ls2 = (bq4->scales_l[ib64] >> 4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4);
-    //const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds);
-    //const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds);
-    //int v1, v2;
-    //int sumi1 = 0, sumi2 = 0;
-    //for (int j = 0; j < 2; ++j) {
-    //    get_int_from_table_16(q4_1[j], values, v1, v2);
-    //    sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1));
-    //    get_int_from_table_16(q4_2[j], values, v1, v2);
-    //    sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2));
-    //}
-    //return d1 * sumi1 + d2 * sumi2;
-
     // iqs is 0...7
     const int ib32 = iqs;
     const int32_t * q8 = (const int *)bq8_1[ib32].qs;
@@ -1259,24 +1271,8 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     }
     return d * (sumi1 + sumi2);

-    //// iqs is 0...15
-    //const int ib32 = iqs/2;
-    //const int il = iqs%2;
-    //const int32_t * q8 = (const int *)bq8_1[ib32].qs + 2*il;
-    //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
-    //const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
-    //const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
-    //int v1, v2;
-    //int sumi1 = 0, sumi2 = 0;
-    //for (int j = 0; j < 2; ++j) {
-    //    get_int_from_table_16(q4[j], values, v1, v2);
-    //    sumi1 = __dp4a(v1, q8[j+0], sumi1);
-    //    sumi2 = __dp4a(v2, q8[j+4], sumi2);
-    //}
-    //return d * (sumi1 + sumi2);
-
 #else
-    assert(false);
-    return 0.f;
+    NO_DEVICE_CODE;
 #endif
 #else
     return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
         struct ggml_tensor * dst = gf->nodes[i];
         GGML_ASSERT(dst->data != nullptr);

+        if (ggml_is_empty(dst)) {
+            continue;
+        }
+
         switch (dst->op) {
             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
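
Note: the `ggml_is_empty(dst)` guards added to the CUDA, Kompute, Metal, and OpenCL graph loops skip nodes that hold no elements. A sketch of the check these backends rely on, assuming it mirrors ggml's convention that a tensor is empty when any dimension has zero extent:

    #include <stdbool.h>
    #include <stdint.h>

    #define GGML_MAX_DIMS_SKETCH 4

    struct ggml_tensor_sketch { int64_t ne[GGML_MAX_DIMS_SKETCH]; };

    /* Assumed semantics of ggml_is_empty: true when any dimension is zero,
     * i.e. the tensor holds no data and there is nothing to compute. */
    static bool is_empty(const struct ggml_tensor_sketch * t) {
        for (int i = 0; i < GGML_MAX_DIMS_SKETCH; ++i) {
            if (t->ne[i] == 0) {
                return true;
            }
        }
        return false;
    }
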
41 ggml-metal.m
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+   GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+   GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+   GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+   GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+   GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -490,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,  get_rows_iq3_s,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,  get_rows_iq2_s,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,  get_rows_iq1_s,  true);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,  get_rows_iq1_m,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,    get_rows_i32,    true);
@@ -517,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,  mul_mv_iq3_s_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,  mul_mv_iq2_s_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,  mul_mv_iq1_s_f32,  ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,  mul_mv_iq1_m_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -540,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,  mul_mv_id_iq3_s_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,  mul_mv_id_iq2_s_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,  mul_mv_id_iq1_s_f32,  ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,  mul_mv_id_iq1_m_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,       mul_mm_f32_f32,       ctx->support_simdgroup_mm);
@@ -560,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,  mul_mm_iq3_s_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,  mul_mm_iq2_s_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,  mul_mm_iq1_s_f32,  ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,  mul_mm_iq1_m_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -580,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,  mul_mm_id_iq3_s_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,  mul_mm_id_iq2_s_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,  mul_mm_id_iq1_s_f32,  ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,  mul_mm_id_iq1_m_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -837,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
        struct ggml_tensor * src2 = gf->nodes[i]->src[2];
        struct ggml_tensor * dst  = gf->nodes[i];

+       if (ggml_is_empty(dst)) {
+           continue;
+       }
+
        switch (dst->op) {
            case GGML_OP_NONE:
            case GGML_OP_RESHAPE:
@@ -1421,6 +1435,7 @@ static enum ggml_status ggml_metal_graph_compute(
                        case GGML_TYPE_IQ3_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
                        case GGML_TYPE_IQ2_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
                        case GGML_TYPE_IQ1_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+                       case GGML_TYPE_IQ1_M:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
                        case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32].pipeline; break;
                        case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32].pipeline; break;
                        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1575,6 +1590,12 @@ static enum ggml_status ggml_metal_graph_compute(
                            nth1 = 16;
                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                        } break;
+                   case GGML_TYPE_IQ1_M:
+                       {
+                           nth0 = 4;
+                           nth1 = 16;
+                           pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+                       } break;
                    case GGML_TYPE_IQ4_NL:
                        {
                            nth0 = 4;
@@ -1619,9 +1640,9 @@ static enum ggml_status ggml_metal_graph_compute(
                [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
                [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

-               if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                   src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                   src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
+               if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+                   src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+                   src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
                else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1743,6 +1764,7 @@ static enum ggml_status ggml_metal_graph_compute(
                        case GGML_TYPE_IQ3_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
                        case GGML_TYPE_IQ2_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
                        case GGML_TYPE_IQ1_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
+                       case GGML_TYPE_IQ1_M:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
                        case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32].pipeline; break;
                        case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32].pipeline; break;
                        default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1900,6 +1922,12 @@ static enum ggml_status ggml_metal_graph_compute(
                            nth1 = 16;
                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                        } break;
+                   case GGML_TYPE_IQ1_M:
+                       {
+                           nth0 = 4;
+                           nth1 = 16;
+                           pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+                       } break;
                    case GGML_TYPE_IQ4_NL:
                        {
                            nth0 = 4;
@@ -1960,9 +1988,9 @@ static enum ggml_status ggml_metal_graph_compute(
                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
                }

-               if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
-                   src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-                   src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
+               if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
+                   src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
+                   src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
                else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -2024,6 +2052,7 @@ static enum ggml_status ggml_metal_graph_compute(
                case GGML_TYPE_IQ3_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
                case GGML_TYPE_IQ2_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
                case GGML_TYPE_IQ1_S:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
+               case GGML_TYPE_IQ1_M:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
                case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL].pipeline; break;
                case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS].pipeline; break;
                case GGML_TYPE_I32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32   ].pipeline; break;
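
Note: the two dispatch-condition hunks fold IQ1_M into the group of quant types whose Metal mat-vec kernels produce 8 output rows per threadgroup, hence the `(ne01 + 7)/8` grid width. A small C sketch of the same rounding, with variable names taken from the diff:

    /* ne01: rows of src0; ne11: columns of src1; ne12, ne13: batch dims.
     * Mirrors MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) above. */
    static void threadgroup_counts(long ne01, long ne11, long ne12, long ne13,
                                   long tg[3]) {
        tg[0] = (ne01 + 7) / 8; /* round rows up to whole groups of 8 */
        tg[1] = ne11;
        tg[2] = ne12 * ne13;
    }
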
216 ggml-metal.metal
@@ -4456,6 +4456,104 @@ void kernel_mul_mv_iq1_s_f32_impl(
         }
     }
 }

+void kernel_mul_mv_iq1_m_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
+    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    iq1m_scale_t scale;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float4 sumy = {0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_m * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint8_t  * qh = xr->qh + 2 * ib;
+        device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+            float2 sum = {0.f};
+            for (int j = 0; j < 4; ++j) {
+                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+            sc += nb*sizeof(block_iq1_m)/2;
+            qs += nb*sizeof(block_iq1_m);
+            qh += nb*sizeof(block_iq1_m);
+        }
+
+        y4 += 32 * 32;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
 void kernel_mul_mv_iq4_nl_f32_impl(
         device const  void * src0,
         device const float * src1,
@@ -4673,6 +4771,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
     kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
 }

+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 [[host_name("kernel_mul_mv_iq4_nl_f32")]]
 kernel void kernel_mul_mv_iq4_nl_f32(
         device const  void * src0,
@@ -5146,6 +5272,30 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
     }
 }

+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    iq1m_scale_t scale;
+    device const uint16_t * sc = (device const uint16_t *)xb->scales;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = scale.f16;
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * qh = xb->qh + 2*ib32 + il;
+    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+        reg[1][i] = dl * (grid1[i] >> 4) + ml1;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+        reg[3][i] = dl * (grid2[i] >> 4) + ml2;
+    }
+}
+
 template <typename type4x4>
 void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
     device const uint16_t * q4 = (device const uint16_t *)xb->qs;
@@ -5730,6 +5880,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_t kernel_get_rows<block_iq1_m,   QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
 #if QK_K == 64
 template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  2,     dequantize_iq4_xs>;
@@ -5778,6 +5929,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_m,   QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
 #if QK_K == 64
 template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_xs>;
@@ -5838,6 +5990,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
 template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m,   QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
 #if QK_K == 64
 template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  2,     dequantize_iq4_xs>;
@@ -7005,6 +7158,69 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
         sgitg);
 }

+[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
+kernel void kernel_mul_mv_id_iq1_m_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq1_m_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
 [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
 kernel void kernel_mul_mv_id_iq4_nl_f32(
         device const    char * ids,
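
Note: `dequantize_iq1_m` reads each `iq1s_grid_gpu` byte twice: the low nibble supplies one of the first four weights of an 8-weight group and the high nibble one of the last four, so a 32-bit grid word packs all eight codebook values. A standalone C sketch of the unpacking; the helper name is hypothetical:

    #include <stdint.h>

    /* Unpack one 32-bit grid word into 8 codebook values. Low nibbles hold
     * values 0..3 of the group, high nibbles values 4..7; the nibbles are
     * small non-negative codes that a delta of roughly -1 +/- IQ1M_DELTA
     * re-centers around {-1, 0, +1} during dequantization. */
    static void iq1m_unpack_grid(uint32_t grid, uint8_t out[8]) {
        for (int j = 0; j < 4; ++j) {
            out[j]     = (grid >> (8*j))     & 0xf; /* low nibble of byte j  */
            out[j + 4] = (grid >> (8*j + 4)) & 0xf; /* high nibble of byte j */
        }
    }
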
@@ -2234,6 +2234,11 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(ggml_backend_t backend) {
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
     for (int i = 0; i < graph->n_nodes; ++i) {
         ggml_tensor * node = graph->nodes[i];
+
+        if (ggml_is_empty(node)) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
611 ggml-quants.c
@@ -3474,6 +3474,54 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
     }
 }

+void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    float delta[4];
+    uint16_t idx[4];
+
+    iq1m_scale_t scale;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+        const float d = GGML_FP16_TO_FP32(scale.f16);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
+            const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
+            idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
+            idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
+            idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
+            idx[3] = qs[3] | ((qh[1] << 4) & 0x700);
+            delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+            for (int l = 0; l < 2; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl1 * (grid[j] + delta[l]);
+                }
+                y += 8;
+            }
+            for (int l = 2; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl2 * (grid[j] + delta[l]);
+                }
+                y += 8;
+            }
+            qs += 4;
+            qh += 2;
+        }
+    }
+}
+
 static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

 void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
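
Note: inside a 256-weight super-block, each 32-weight group carries two 3-bit sub-scales (one per 16-weight half), packed two groups per 16-bit `scales` word; `2*s + 1` maps the 3-bit code to an odd multiplier 1..15 of the super-block scale `d`. A standalone C sketch of the decode, mirroring `dl1`/`dl2` above; the helper name is hypothetical:

    #include <stdint.h>

    /* Decode the two per-half multipliers for 32-weight group ib (0..7) of
     * an IQ1_M super-block. sc points at the block's packed scale words;
     * d is the fp32 super-block scale. */
    static void iq1m_sub_scales(const uint16_t * sc, int ib, float d,
                                float * dl1, float * dl2) {
        const int shift = 6*(ib%2);                           /* two groups per word */
        *dl1 = d * (2*((sc[ib/2] >> (shift + 0)) & 0x7) + 1); /* first 16 weights */
        *dl2 = d * (2*((sc[ib/2] >> (shift + 3)) & 0x7) + 1); /* last 16 weights  */
    }
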
@ -9695,6 +9743,206 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_m * restrict x = vx;
    const block_q8_K  * restrict y = vy;

    const int nb = n / QK_K;

    iq1m_scale_t scale;

#if defined __ARM_NEON

    const int32x4_t mask  = vdupq_n_s32(0x7);
    const int32x4_t mone  = vdupq_n_s32(1);
    const int32x4_t mzero = vdupq_n_s32(0);

    ggml_int8x16x4_t deltas;
    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));

    ggml_int8x16x4_t q1b;
    ggml_int8x16x4_t q8b;

    uint32_t aux32;
    const uint8_t * aux8 = (const uint8_t *)&aux32;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        int32x4_t sumi1 = mzero;
        int32x4_t sumi2 = mzero;

        for (int ib = 0; ib < QK_K/32; ib += 2) {

            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));

            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;

            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
            const int32x4_t p12 = vpaddq_s32(p1, p2);

            const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that
            aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);

            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
            const int32x4_t p34 = vpaddq_s32(p3, p4);

            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);

            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
            sumi2 = vmlaq_s32(sumi2, scales_4, p34);

            qs += 8; qh += 4;

        }

        sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
    }

    *s = sumf;

#elif defined __AVX2__

    const __m256i mask = _mm256_set1_epi16(0x7);
    const __m256i mone = _mm256_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m256i q1b_1 = _mm256_set_epi64x(
                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
            );
            const __m256i q1b_2 = _mm256_set_epi64x(
                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
            );
            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);

            const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
            const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);

            const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
            const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
            scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
            scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
            const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
            const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
            const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
            const __m256i p4 = _mm256_madd_epi16(dot4, scale2);

            sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
            sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));

            qs += 8; qh += 4;
        }

        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
        accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
        accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);

    }

    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);

#else

    int sum1[2], sum2[2], delta[4];

    float sumf = 0;
    for (int i = 0; i < nb; i++) {

        const int8_t   * q8 = y[i].qs;
        const uint8_t  * qs = x[i].qs;
        const uint8_t  * qh = x[i].qh;
        const uint16_t * sc = (const uint16_t *)x[i].scales;

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        int sumi1 = 0, sumi2 = 0;
        for (int ib = 0; ib < QK_K/32; ++ib) {
            delta[0] = qh[0] & 0x08 ? -1 : 1;
            delta[1] = qh[0] & 0x80 ? -1 : 1;
            delta[2] = qh[1] & 0x08 ? -1 : 1;
            delta[3] = qh[1] & 0x80 ? -1 : 1;
            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
            for (int l = 0; l < 4; ++l) {
                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
                int lsum1 = 0, lsum2 = 0;
                for (int j = 0; j < 8; ++j) {
                    lsum1 += q8[j] * grid[j];
                    lsum2 += q8[j];
                }
                q8 += 8;
                sum1[l/2] += lsum1;
                sum2[l/2] += lsum2*delta[l];
            }
            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
            qs += 4;
            qh += 2;
        }

        sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
    }

    *s = sumf;

#endif
}
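The NEON path condenses the per-group sign handling of the scalar fallback: for each qh byte, bit 3 flips the delta sign of the first 8 values and bit 7 that of the next 8, and the shifted-mask expression turns those two bits into a 2-bit index into the precomputed deltas table. A small standalone check of that bit gymnastics (hypothetical helper, derived only from the expressions above):

#include <stdint.h>
#include <string.h>
#include <assert.h>

static void iq1m_delta_index_check(void) {
    uint8_t qh[4] = {0x00, 0x08, 0x80, 0x88};
    uint32_t qh32;
    memcpy(&qh32, qh, 4); // same bytes the kernel reads as a 4-byte-aligned word
    uint32_t aux32 = ((qh32 >> 3) & 0x01010101) | ((qh32 >> 6) & 0x02020202);
    const uint8_t * aux8 = (const uint8_t *)&aux32;
    assert(aux8[0] == 0); // no sign bits set      -> deltas.val[0] = (+1, +1)
    assert(aux8[1] == 1); // bit 3: low half  -1   -> deltas.val[1] = (-1, +1)
    assert(aux8[2] == 2); // bit 7: high half -1   -> deltas.val[2] = (+1, -1)
    assert(aux8[3] == 3); // both halves      -1   -> deltas.val[3] = (-1, -1)
}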
void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);

@ -9938,17 +10186,17 @@ static iq2_entry_t iq2_data[4] = {
};

static inline int iq2_data_index(enum ggml_type type) {
    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
    return type == GGML_TYPE_IQ2_XXS ? 0 :
           type == GGML_TYPE_IQ2_XS  ? 1 :
           type == GGML_TYPE_IQ1_S ? 2 : 3;
           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3;
}

static inline int iq2_grid_size(enum ggml_type type) {
    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
    return type == GGML_TYPE_IQ2_XXS ? 256 :
           type == GGML_TYPE_IQ2_XS  ? 512 :
           type == GGML_TYPE_IQ1_S ? NGRID_IQ1S : 1024;
           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024;
}

static int iq2_compare_func(const void * left, const void * right) {

@ -10214,10 +10462,10 @@ void iq2xs_init_impl(enum ggml_type type) {

    const int kmap_size = 43692;
    //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
    const int nwant = type == GGML_TYPE_IQ1_S ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
    const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
    const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
                             type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :
                             type == GGML_TYPE_IQ1_S ? kgrid_1bit_2048 : kgrid_2bit_1024;
                             type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024;
    uint64_t * kgrid_q2xs;
    int      * kmap_q2xs;
    uint16_t * kneighbors_q2xs;

@ -10314,7 +10562,7 @@ void iq2xs_init_impl(enum ggml_type type) {
}

void iq2xs_free_impl(enum ggml_type type) {
    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
    const int gindex = iq2_data_index(type);
    if (iq2_data[gindex].grid) {
        free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;

@ -11520,7 +11768,16 @@ static int iq1_sort_helper(const void * left, const void * right) {
}

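Both 1-bit types map to the same table slot, so the expensive grid and neighbour tables built by iq2xs_init_impl() are shared between IQ1_S and IQ1_M. Each codebook index is 11 bits wide: the low 8 come from one qs byte, the high 3 from a nibble of qh. A tiny sketch of that index assembly (hypothetical helper, mirroring the qs[l] | ((qh << k) & 0x700) expressions used throughout):

#include <stdint.h>

// One of NGRID_IQ1S == 2048 grid entries; only the low 3 bits of the qh
// nibble contribute (bit 3 of each nibble is the delta sign, handled separately).
static int iq1_grid_index(uint8_t qs_byte, uint8_t qh_nibble) {
    return qs_byte | (((int)qh_nibble << 8) & 0x700);
}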
#define IQ1S_BLOCK_SIZE 32
static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
#define IQ1M_BLOCK_SIZE 16

static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
        float    * scales,
        float    * weight,
        float    * sumx,
        float    * sumw,
        float    * pairs,
        int8_t   * L,
        uint16_t * index,
        int8_t   * shifts) {

    const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);

@ -11534,22 +11791,17 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(n%QK_K == 0);

    block_iq1_s * y = vy;

    const int nbl = n/QK_K;

    block_iq1_s * y = vy;
    const int block_size = IQ1S_BLOCK_SIZE;

    const float x_p[3] = {-1 + IQ1S_DELTA,  IQ1S_DELTA, 1 + IQ1S_DELTA};
    const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};

    float  scales[QK_K/IQ1S_BLOCK_SIZE];
    float  weight[IQ1S_BLOCK_SIZE];
    int8_t L[IQ1S_BLOCK_SIZE];
    float  sumx[IQ1S_BLOCK_SIZE+1];
    float  sumw[IQ1S_BLOCK_SIZE+1];
    float  pairs[2*IQ1S_BLOCK_SIZE];
    int * idx = (int *)(pairs + 1);
    uint16_t index[IQ1S_BLOCK_SIZE/8];
    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];

    for (int ibl = 0; ibl < nbl; ++ibl) {

@ -11564,15 +11816,15 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
        float sigma2 = 2*sumx2/QK_K;

        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
        for (int ib = 0; ib < QK_K/block_size; ++ib) {
            const float * xb = xbl + IQ1S_BLOCK_SIZE*ib;
            const float * xb = xbl + block_size*ib;
            const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib;
            const float * qw = quant_weights + QK_K*ibl + block_size*ib;
            for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            float max = fabsf(xb[0]);
            for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i]));
            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
            if (!max) {
                scales[ib] = 0;
                memset(L, 1, IQ1S_BLOCK_SIZE);
                memset(L, 1, block_size);
                continue;
            }
            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.

@ -11581,14 +11833,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
            // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
            // and score for each possible split.
            for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
            for (int j = 0; j < block_size; ++j) {
                pairs[2*j] = xb[j];
                idx[2*j] = j;
            }
            qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper);
            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
            {
                sumx[0] = sumw[0] = 0;
                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
                for (int j = 0; j < block_size; ++j) {
                    int i = idx[2*j];
                    sumx[j+1] = sumx[j] + weight[i]*xb[i];
                    sumw[j+1] = sumw[j] + weight[i];

@ -11596,16 +11848,16 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
            }
            float best_score = 0, scale = max;
            int besti1 = -1, besti2 = -1, best_shift = 0;
            for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
            for (int i1 = 0; i1 <= block_size; ++i1) {
                for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
                for (int i2 = i1; i2 <= block_size; ++i2) {
                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2])*x_p[2];
                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2];
                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2])*x_p[2]*x_p[2];
                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2];
                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                        scale = sumqx/sumq2; best_score = scale*sumqx;
                        besti1 = i1; besti2 = i2; best_shift = 1;
                    }
                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2])*x_m[2];
                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2];
                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2])*x_m[2]*x_m[2];
                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2];
                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
                        scale = sumqx/sumq2; best_score = scale*sumqx;
                        besti1 = i1; besti2 = i2; best_shift = -1;

@ -11615,14 +11867,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
            for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
            for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
            if (scale < 0) {
                for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
                scale = -scale; best_shift = -best_shift;
            }
            bool all_on_grid = true;
            const float * xx = best_shift == 1 ? x_p : x_m;
            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
            for (int k = 0; k < block_size/8; ++k) {
                uint16_t u = 0;
                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
                int grid_index = kmap_q2xs[u];

@ -11636,7 +11888,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
            }
            if (!all_on_grid) {
                float sumqx = 0, sumq2 = 0;
                for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
                for (int k = 0; k < block_size/8; ++k) {
                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
                    for (int j = 0; j < 8; ++j) {
                        float w = weight[8*k + j];

@ -11648,8 +11900,8 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
                if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
            }
            uint16_t h = 0;
            for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
            for (int k = 0; k < block_size/8; ++k) {
                y[ibl].qs[(IQ1S_BLOCK_SIZE/8)*ib + k] = index[k] & 255;
                y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255;
                h |= (index[k] >> 8) << 3*k;
            }
            y[ibl].qh[ib] = h;

@ -11660,14 +11912,13 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
        }

        if (!max_scale) {
            memset(y[ibl].qs, 0, QK_K/8);
            continue;
        }

        float d = max_scale/15;
        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.085f is another fudge factor. Don't ask me why it is needed.
        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
        float id = 1/d;
        for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
        for (int ib = 0; ib < QK_K/block_size; ++ib) {
            int l = nearest_int(0.5f*(id*scales[ib]-1));
            l = MAX(0, MIN(7, l));
            if (shifts[ib] == -1) l |= 8;

@ -11678,16 +11929,292 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy

size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
    GGML_ASSERT(n_per_row%QK_K == 0);
    float  scales[QK_K/IQ1S_BLOCK_SIZE];
    float  weight[IQ1S_BLOCK_SIZE];
    int8_t L[IQ1S_BLOCK_SIZE];
    float  sumx[IQ1S_BLOCK_SIZE+1];
    float  sumw[IQ1S_BLOCK_SIZE+1];
    float  pairs[2*IQ1S_BLOCK_SIZE];
    uint16_t index[IQ1S_BLOCK_SIZE/8];
    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
    int nblock = n_per_row/QK_K;
    char * qrow = (char *)dst;
    for (int row = 0; row < nrow; ++row) {
        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
        src += n_per_row;
        qrow += nblock*sizeof(block_iq1_s);
    }
    return nrow * nblock * sizeof(block_iq1_s);
}
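The row scale d = max_scale/15 and the 3-bit block codes interact so that code l reproduces the block scale d*(2l+1), with l = 7 recovering max_scale exactly. A small sketch of that encoding (hypothetical helper; roundf stands in for ggml's nearest_int):

#include <math.h>
#include <assert.h>

// Mirrors l = nearest_int(0.5f*(id*scales[ib]-1)) with clamping to 0..7.
static int encode_block_scale(float block_scale, float d) {
    int l = (int)roundf(0.5f*(block_scale/d - 1));
    return l < 0 ? 0 : l > 7 ? 7 : l;
}

static void iq1s_scale_code_check(void) {
    float max_scale = 3.0f;
    float d = max_scale/15;
    assert(encode_block_scale(max_scale, d) == 7);  // largest block gets code 7
    assert(fabsf(d*(2*7 + 1) - max_scale) < 1e-6f); // and decodes back to max_scale
}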
static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights,
        float    * scales,
        float    * weight,
        float    * pairs,
        int8_t   * L,
        uint16_t * index,
        int8_t   * shifts) {

    const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);

    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
    const int      * kmap_q2xs       = iq2_data[gindex].map;
    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;

    //GGML_ASSERT(quant_weights   && "missing quantization weights");
    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(n%QK_K == 0);

    block_iq1_m * y = vy;

    const int nbl = n/QK_K;

    const int block_size = IQ1M_BLOCK_SIZE;

    const float x_p[3] = {-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA};
    const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
    const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};

    int * idx = (int *)(pairs + 1);

    float sumqx[4], sumq2[4];

    iq1m_scale_t s;
    const float * xx;

    for (int ibl = 0; ibl < nbl; ++ibl) {

        //y[ibl].d = GGML_FP32_TO_FP16(0.f);
        memset(y[ibl].qs, 0, QK_K/8);
        memset(y[ibl].qh, 0, QK_K/16);
        memset(y[ibl].scales, 0, QK_K/32);

        float max_scale = 0;

        const float * xbl = x + QK_K*ibl;
        float sumx2 = 0;
        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
        float sigma2 = 2*sumx2/QK_K;

        for (int ib = 0; ib < QK_K/block_size; ++ib) {
            const float * xb = xbl + block_size*ib;
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            } else {
                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
            }
            float max = fabsf(xb[0]);
            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
            if (!max) {
                scales[ib] = 0;
                memset(L, 1, block_size);
                continue;
            }
            // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
            // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
            // and score for each possible split.
            for (int j = 0; j < block_size; ++j) {
                pairs[2*j] = xb[j];
                idx[2*j] = j;
            }
            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
            float best_score = 0, scale = max;
            int besti1 = -1, besti2 = -1, best_k = -1;
            // 0: +, +
            // 1: +, -
            // 2: -, +
            // 3: -, -
            for (int i1 = 0; i1 <= block_size; ++i1) {
                for (int i2 = i1; i2 <= block_size; ++i2) {
                    memset(sumqx, 0, 4*sizeof(float));
                    memset(sumq2, 0, 4*sizeof(float));
                    for (int j = 0; j < i1; ++j) {
                        int i = idx[2*j];
                        if (i < block_size/2) {
                            sumqx[0] += weight[i]*x_p[0]*xb[i];
                            sumqx[1] += weight[i]*x_p[0]*xb[i];
                            sumqx[2] += weight[i]*x_m[0]*xb[i];
                            sumqx[3] += weight[i]*x_m[0]*xb[i];
                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
                            sumq2[1] += weight[i]*x_p[0]*x_p[0];
                            sumq2[2] += weight[i]*x_m[0]*x_m[0];
                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
                        } else {
                            sumqx[0] += weight[i]*x_p[0]*xb[i];
                            sumqx[2] += weight[i]*x_p[0]*xb[i];
                            sumqx[1] += weight[i]*x_m[0]*xb[i];
                            sumqx[3] += weight[i]*x_m[0]*xb[i];
                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
                            sumq2[2] += weight[i]*x_p[0]*x_p[0];
                            sumq2[1] += weight[i]*x_m[0]*x_m[0];
                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
                        }
                    }
                    for (int j = i1; j < i2; ++j) {
                        int i = idx[2*j];
                        if (i < block_size/2) {
                            sumqx[0] += weight[i]*x_p[1]*xb[i];
                            sumqx[1] += weight[i]*x_p[1]*xb[i];
                            sumqx[2] += weight[i]*x_m[1]*xb[i];
                            sumqx[3] += weight[i]*x_m[1]*xb[i];
                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
                            sumq2[1] += weight[i]*x_p[1]*x_p[1];
                            sumq2[2] += weight[i]*x_m[1]*x_m[1];
                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
                        } else {
                            sumqx[0] += weight[i]*x_p[1]*xb[i];
                            sumqx[2] += weight[i]*x_p[1]*xb[i];
                            sumqx[1] += weight[i]*x_m[1]*xb[i];
                            sumqx[3] += weight[i]*x_m[1]*xb[i];
                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
                            sumq2[2] += weight[i]*x_p[1]*x_p[1];
                            sumq2[1] += weight[i]*x_m[1]*x_m[1];
                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
                        }
                    }
                    for (int j = i2; j < block_size; ++j) {
                        int i = idx[2*j];
                        if (i < block_size/2) {
                            sumqx[0] += weight[i]*x_p[2]*xb[i];
                            sumqx[1] += weight[i]*x_p[2]*xb[i];
                            sumqx[2] += weight[i]*x_m[2]*xb[i];
                            sumqx[3] += weight[i]*x_m[2]*xb[i];
                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
                            sumq2[1] += weight[i]*x_p[2]*x_p[2];
                            sumq2[2] += weight[i]*x_m[2]*x_m[2];
                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
                        } else {
                            sumqx[0] += weight[i]*x_p[2]*xb[i];
                            sumqx[2] += weight[i]*x_p[2]*xb[i];
                            sumqx[1] += weight[i]*x_m[2]*xb[i];
                            sumqx[3] += weight[i]*x_m[2]*xb[i];
                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
                            sumq2[2] += weight[i]*x_p[2]*x_p[2];
                            sumq2[1] += weight[i]*x_m[2]*x_m[2];
                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
                        }
                    }
                    for (int k = 0; k < 4; ++k) {
                        if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
                            scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
                            besti1 = i1; besti2 = i2; best_k = k;
                        }
                    }
                }
            }
            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
            for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
            if (scale < 0) {
                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
                scale = -scale;
                best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;
            }
            bool all_on_grid = true;
            for (int k = 0; k < block_size/8; ++k) {
                if (k == 0) xx = best_k < 2 ? x_p : x_m;
                else xx = best_k%2 == 0 ? x_p : x_m;
                uint16_t u = 0;
                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
                int grid_index = kmap_q2xs[u];
                if (grid_index < 0) {
                    all_on_grid = false;
                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
                    GGML_ASSERT(grid_index >= 0);
                }
                index[k] = grid_index;
            }
            if (!all_on_grid) {
                float sumqx_f = 0, sumq2_f = 0;
                for (int k = 0; k < block_size/8; ++k) {
                    if (k == 0) xx = best_k < 2 ? x_p : x_m;
                    else xx = best_k%2 == 0 ? x_p : x_m;
                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
                    for (int j = 0; j < 8; ++j) {
                        float w = weight[8*k + j];
                        float q = xx[(pg[j] - 1)/2];
                        sumqx_f += w*q*xb[8*k+j];
                        sumq2_f += w*q*q;
                    }
                }
                if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
            }
            y[ibl].qs[2*ib + 0] = index[0] & 255;
            y[ibl].qs[2*ib + 1] = index[1] & 255;
            y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4);
            GGML_ASSERT(scale >= 0);
            scales[ib] = scale;
            shifts[ib] = best_k;
            max_scale = MAX(max_scale, scale);
        }

        if (!max_scale) {
            continue;
        }

        uint16_t * sc = (uint16_t *)y[ibl].scales;
        float d = max_scale/15;
        float id = 1/d;
        float sumqx_f = 0, sumq2_f = 0;
        for (int ib = 0; ib < QK_K/block_size; ++ib) {
            int l = nearest_int(0.5f*(id*scales[ib+0]-1));
            l = MAX(0, MIN(7, l));
            sc[ib/4] |= (l << 3*(ib%4));
            y[ibl].qh[ib] |= masks[shifts[ib]];
            const float * xb = xbl + block_size*ib;
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            } else {
                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
            }
            for (int k = 0; k < block_size/8; ++k) {
                if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;
                else xx = shifts[ib]%2 == 0 ? x_p : x_m;
                const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));
                for (int j = 0; j < 8; ++j) {
                    float w = weight[8*k + j];
                    float q = xx[(pg[j] - 1)/2]*(2*l+1);
                    sumqx_f += w*q*xb[8*k+j];
                    sumq2_f += w*q*q;
                }
            }
        }
        if (sumq2_f > 0) d = sumqx_f/sumq2_f;
        s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
        sc[0] |= ((s.u16 & 0x000f) << 12);
        sc[1] |= ((s.u16 & 0x00f0) <<  8);
        sc[2] |= ((s.u16 & 0x0f00) <<  4);
        sc[3] |= ((s.u16 & 0xf000) <<  0);
    }
}

size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
    GGML_ASSERT(n_per_row%QK_K == 0);
    float  scales[QK_K/IQ1M_BLOCK_SIZE];
    float  weight[IQ1M_BLOCK_SIZE];
    int8_t L[IQ1M_BLOCK_SIZE];
    float  pairs[2*IQ1M_BLOCK_SIZE];
    uint16_t index[IQ1M_BLOCK_SIZE/8];
    int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
    int nblock = n_per_row/QK_K;
    char * qrow = (char *)dst;
    for (int row = 0; row < nrow; ++row) {
        quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
        src += n_per_row;
        qrow += nblock*sizeof(block_iq1_m);
    }
    return nrow * nblock * sizeof(block_iq1_m);
}
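The memsets in quantize_row_iq1_m_impl show the on-disk layout of one IQ1_M block: QK_K/8 bytes of qs, QK_K/16 of qh and QK_K/32 of scales, with the fp16 super-block scale hidden in the scale nibbles rather than stored separately. A back-of-the-envelope size check (assuming QK_K == 256 and no other fields in block_iq1_m):

#include <assert.h>

static void iq1_m_bpw_check(void) {
    const int QK = 256;
    const int bytes = QK/8 /*qs*/ + QK/16 /*qh*/ + QK/32 /*scales*/; // 32 + 16 + 8
    assert(bytes == 56);
    // 56 bytes * 8 bits / 256 weights = 1.75 bits per weight
    assert(8.0*bytes/QK == 1.75);
}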
// ============================ 4-bit non-linear quants

static inline int best_index_int8(int n, const int8_t * val, float x) {

@ -72,6 +72,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq1_m  (const block_iq1_m   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);

@ -94,6 +95,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

@ -104,6 +106,7 @@ size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
size_t quantize_iq2_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
size_t quantize_iq1_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
size_t quantize_iq1_m  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);
size_t quantize_iq3_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix);

@ -16973,7 +16973,7 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
    params.ith = 0;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
            continue;
        }
#ifndef NDEBUG

@ -5566,7 +5566,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];

        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
            continue;
        }

48
ggml.c
@ -794,6 +794,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
        .vec_dot_type             = GGML_TYPE_Q8_K,
        .nrows                    = 1,
    },
    [GGML_TYPE_IQ1_M] = {
        .type_name                = "iq1_m",
        .blck_size                = QK_K,
        .type_size                = sizeof(block_iq1_m),
        .is_quantized             = true,
        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_m,
        .from_float               = NULL,
        .from_float_reference     = NULL,
        .vec_dot                  = ggml_vec_dot_iq1_m_q8_K,
        .vec_dot_type             = GGML_TYPE_Q8_K,
        .nrows                    = 1,
    },
    [GGML_TYPE_IQ4_NL] = {
        .type_name                = "iq4_nl",
        .blck_size                = QK4_NL,
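Note that the new entry registers only a to_float path; from_float is NULL, so the only way into the IQ1_M encoding is quantize_iq1_m() (ideally with an importance matrix). A minimal sketch (not from this diff; helper name hypothetical) of dequantizing one block through the traits table, assuming the public ggml_internal_get_type_traits() accessor:

#include "ggml.h"

void dequant_one_iq1_m_block(const void * qblock, float * out /* QK_K floats */) {
    ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_IQ1_M);
    // .from_float is NULL for iq1_m, so tt.from_float must not be called;
    // .to_float dispatches to dequantize_row_iq1_m as registered above.
    tt.to_float(qblock, out, tt.blck_size);
}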
@ -2539,6 +2551,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
        case GGML_FTYPE_MOSTLY_IQ2_XS:  wtype = GGML_TYPE_IQ2_XS;  break;
        case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
        case GGML_FTYPE_MOSTLY_IQ1_S:   wtype = GGML_TYPE_IQ1_S;   break;
        case GGML_FTYPE_MOSTLY_IQ1_M:   wtype = GGML_TYPE_IQ1_M;   break;
        case GGML_FTYPE_MOSTLY_IQ4_NL:  wtype = GGML_TYPE_IQ4_NL;  break;
        case GGML_FTYPE_MOSTLY_IQ4_XS:  wtype = GGML_TYPE_IQ4_XS;  break;
        case GGML_FTYPE_MOSTLY_IQ3_S:   wtype = GGML_TYPE_IQ3_S;   break;

@ -2594,6 +2607,16 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}

GGML_CALL bool ggml_is_empty(const struct ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        if (tensor->ne[i] == 0) {
            // empty if any dimension has no elements
            return true;
        }
    }
    return false;
}

bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

@ -2608,7 +2631,7 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

    return
    return ggml_is_empty(t0) ? ggml_is_empty(t1) :
        (t1->ne[0]%t0->ne[0] == 0) &&
        (t1->ne[1]%t0->ne[1] == 0) &&
        (t1->ne[2]%t0->ne[2] == 0) &&

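The ggml_is_empty() guard gives ggml_can_repeat() a defined answer for zero-element tensors: an empty source only "repeats" into an empty target, which also keeps the modulo checks from dividing by zero. A minimal standalone sketch of that predicate on raw shape arrays (hypothetical helper, not part of ggml):

#include <stdbool.h>
#include <stdint.h>

static bool can_repeat_ne(const int64_t ne0[4], const int64_t ne1[4]) {
    bool empty0 = false, empty1 = false;
    for (int i = 0; i < 4; ++i) {
        empty0 |= ne0[i] == 0;
        empty1 |= ne1[i] == 0;
    }
    if (empty0) return empty1;              // empty repeats only into empty
    return ne1[0] % ne0[0] == 0 &&          // otherwise every dim must tile
           ne1[1] % ne0[1] == 0 &&
           ne1[2] % ne0[2] == 0 &&
           ne1[3] % ne0[3] == 0;
}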
@ -8135,6 +8158,7 @@ static void ggml_compute_forward_add(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -8417,6 +8441,7 @@ static void ggml_compute_forward_add1(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -8544,6 +8569,7 @@ static void ggml_compute_forward_acc(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -11447,6 +11473,7 @@ static void ggml_compute_forward_out_prod(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -11638,6 +11665,7 @@ static void ggml_compute_forward_set(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -11861,6 +11889,7 @@ static void ggml_compute_forward_get_rows(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -12564,6 +12593,7 @@ static void ggml_compute_forward_alibi(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -12652,6 +12682,7 @@ static void ggml_compute_forward_clamp(
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:

@ -16093,7 +16124,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
    GGML_ASSERT(params);

    if (tensor->op == GGML_OP_NONE) {
    if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
        return;
    }

@ -17962,6 +17993,12 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
    int n_tasks = 0;

    if (ggml_is_empty(node)) {
        // no need to multi-thread a no-op
        n_tasks = 1;
        return n_tasks;
    }

    switch (node->op) {
        case GGML_OP_CPY:
        case GGML_OP_DUP:

@ -20306,7 +20343,8 @@ void ggml_quantize_init(enum ggml_type type) {
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ1_S:   iq2xs_init_impl(type); break;
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;
        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
        default: // nothing

@ -20331,7 +20369,8 @@ bool ggml_quantize_requires_imatrix(enum ggml_type type) {
    return
        type == GGML_TYPE_IQ2_XXS ||
        type == GGML_TYPE_IQ2_XS  ||
        type == GGML_TYPE_IQ1_S;
        type == GGML_TYPE_IQ1_S;// ||
        //type == GGML_TYPE_IQ1_M;
}

size_t ggml_quantize_chunk(
|
@ -20375,6 +20414,7 @@ size_t ggml_quantize_chunk(
|
||||||
case GGML_TYPE_IQ3_S: result = quantize_iq3_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ3_S: result = quantize_iq3_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
|
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
#if QK_K == 64
|
#if QK_K == 64
|
||||||
case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
|
|
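A minimal usage sketch for the new dispatch entry (not from the diff; parameter types assume the int-based signature that quantize_iq1_m() itself uses here). ggml_quantize_init() must run first so the shared 1-bit grid and neighbour tables exist; note that unlike IQ1_S, IQ1_M does not yet hard-require an imatrix (see the commented-out line in ggml_quantize_requires_imatrix above), though quality depends heavily on one:

#include "ggml.h"

size_t quantize_rows_iq1_m(const float * src, void * dst, int nrows, int n_per_row, const float * imatrix) {
    ggml_quantize_init(GGML_TYPE_IQ1_M); // builds the IQ1 grid/neighbour tables once
    return ggml_quantize_chunk(GGML_TYPE_IQ1_M, src, dst, /*start=*/0, nrows, n_per_row, imatrix);
}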
3
ggml.h
@ -369,6 +369,7 @@ extern "C" {
        GGML_TYPE_I32 = 26,
        GGML_TYPE_I64 = 27,
        GGML_TYPE_F64 = 28,
        GGML_TYPE_IQ1_M = 29,
        GGML_TYPE_COUNT,
    };

@ -408,6 +409,7 @@ extern "C" {
        GGML_FTYPE_MOSTLY_IQ3_S  = 20, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_S  = 21, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
    };

    // available tensor operations:

@ -748,6 +750,7 @@ extern "C" {
    GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
    GGML_API GGML_CALL bool ggml_is_permuted  (const struct ggml_tensor * tensor);
    GGML_API GGML_CALL bool ggml_is_empty     (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);

@ -728,6 +728,7 @@ class GGMLQuantizationType(IntEnum):
    I32 = 26
    I64 = 27
    F64 = 28
    IQ1_M = 29


class GGUFEndian(IntEnum):

26
llama.h
@ -39,7 +39,7 @@
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 4
|
#define LLAMA_SESSION_VERSION 5
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@@ -117,6 +117,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ2_S  = 28, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ2_M  = 29, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors

         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
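Selecting the new file type from client code only requires setting ftype on the default quantize params. A minimal sketch; the file paths are placeholders:

    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;

    // returns 0 on success; input/output paths are placeholders
    const uint32_t rc = llama_model_quantize("model-f16.gguf", "model-iq1_m.gguf", &qparams);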
@@ -284,6 +285,7 @@ extern "C" {
         bool only_copy;      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure;           // quantize all tensors to the default type
         void * imatrix;      // pointer to importance matrix data
+        void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;

     // grammar types
@@ -676,23 +678,29 @@ extern "C" {
     LLAMA_API void llama_synchronize(struct llama_context * ctx);

     // Token logits obtained from the last call to llama_decode()
-    // The logits for the last token are stored in the last row
-    // Logits for which llama_batch.logits[i] == 0 are undefined
-    // Rows: n_tokens provided with llama_batch
+    // The logits for which llama_batch.logits[i] != 0 are stored contiguously
+    // in the order they have appeared in the batch.
+    // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);

     // Logits for the ith token. Equivalent to:
-    // llama_get_logits(ctx) + i*n_vocab
+    // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+    // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

-    // Get all output token embeddings
-    // shape: [n_tokens*n_embd] (1-dimensional)
+    // Get all output token embeddings.
+    // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
+    // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
+    // in the order they have appeared in the batch.
+    // shape: [n_outputs*n_embd]
+    // Otherwise, returns NULL.
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

-    // Get the embeddings for the ith token
-    // llama_get_embeddings(ctx) + i*n_embd
+    // Get the embeddings for the ith token. Equivalent to:
+    // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
     // shape: [n_embd] (1-dimensional)
+    // returns NULL for invalid ids.
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);

     // Get the embeddings for a sequence id
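To make the new contract concrete: logits are no longer indexed by raw batch position unless that position actually requested output. A minimal usage sketch under the updated semantics — ctx, tokens, and n_tokens are assumed to exist, and only the final position requests output, so the logits buffer holds exactly one row:

    llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = i == n_tokens - 1;  // request output for the last token only
    }
    batch.n_tokens = n_tokens;

    if (llama_decode(ctx, batch) == 0) {
        // valid because position n_tokens-1 was flagged; unflagged positions return NULL
        const float * last_logits = llama_get_logits_ith(ctx, n_tokens - 1);
        (void) last_logits;
    }
    llama_batch_free(batch);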
tests/test-backend-ops.cpp

@@ -1960,7 +1960,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
         GGML_TYPE_Q6_K,
         GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
-        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S,
+        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
     };
1651 unicode-data.cpp (new file)
File diff suppressed because it is too large
16 unicode-data.h (new file)

@@ -0,0 +1,16 @@
+#pragma once
+
+#include <cstdint>
+#include <map>
+#include <utility>
+#include <vector>
+
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_digit;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
+extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
+extern const std::map<char32_t, char32_t> unicode_map_lowercase;
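Each range table is a list of inclusive [first, last] codepoint pairs, so membership tests reduce to a binary search. A small sketch of such a lookup, assuming the tables are sorted by range start (the helper name is illustrative):

    #include <algorithm>
    #include "unicode-data.h"

    // sketch: does the codepoint fall inside any [first, last] range of the table?
    static bool cpt_in_ranges(const std::vector<std::pair<uint32_t, uint32_t>> & ranges,
                              uint32_t cpt) {
        // find the first range starting strictly beyond cpt, then step back one
        auto it = std::upper_bound(ranges.begin(), ranges.end(), cpt,
            [](uint32_t c, const std::pair<uint32_t, uint32_t> & r) { return c < r.first; });
        return it != ranges.begin() && cpt <= std::prev(it)->second;
    }

    // e.g. cpt_in_ranges(unicode_ranges_digit, 0x0037) should be true for '7'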
1411 unicode.cpp
File diff suppressed because it is too large
unicode.h

@@ -24,3 +24,5 @@ int unicode_cpt_type(const std::string & utf8);
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);

+// simple tolower that only implements one-to-one mapping, not one-to-many
+char32_t unicode_tolower(char32_t cp);
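Given the one-to-one lowercase table exported by unicode-data.h, unicode_tolower is presumably a map lookup with identity fallback. A sketch of that shape (the real implementation lives in unicode.cpp; the function name here is illustrative):

    #include "unicode-data.h"

    // sketch: one-to-one lowercase mapping; codepoints without an entry map to themselves
    char32_t tolower_sketch(char32_t cp) {
        const auto it = unicode_map_lowercase.find(cp);
        return it == unicode_map_lowercase.end() ? cp : it->second;
    }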