Merge branch 'master' into hp/download-model-from-hf

# Conflicts:
#	common/common.cpp
This commit is contained in:
Pierrick HYMBERT 2024-03-16 16:57:24 +01:00
commit 1430e895fc
5 changed files with 538 additions and 120 deletions

22
.github/workflows/close-issue.yml vendored Normal file
View file

@ -0,0 +1,22 @@
# Marks issues as stale after a period of inactivity and closes them later.
# Pull requests are left alone (see the days-before-pr-* settings below).
name: Close inactive issues
on:
  schedule:
    # Runs once a day at 00:42.
    - cron: "42 0 * * *"
jobs:
  close-issues:
    runs-on: ubuntu-latest
    # Write access is needed to label, comment on and close issues.
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v5
        with:
          # An issue is labelled stale after 30 days without activity ...
          days-before-issue-stale: 30
          # ... and closed after a further 14 days without activity.
          days-before-issue-close: 14
          stale-issue-label: "stale"
          stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
          close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
          # -1 turns off stale marking and closing for pull requests.
          days-before-pr-stale: -1
          days-before-pr-close: -1
          repo-token: ${{ secrets.GITHUB_TOKEN }}

View file

@ -134,6 +134,7 @@ Typically finetunes of the base models below are supported as well.
- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp) - JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp)
- JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm) - JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm)
- TypeScript/Wasm (nicer API, available on npm): [ngxson/wllama](https://github.com/ngxson/wllama)
- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)
- Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
- Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs)

View file

@ -167,13 +167,17 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::replace(arg.begin(), arg.end(), '_', '-'); std::replace(arg.begin(), arg.end(), '_', '-');
} }
bool arg_found = false;
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.seed = std::stoul(argv[i]); params.seed = std::stoul(argv[i]);
} else if (arg == "-t" || arg == "--threads") { }
if (arg == "-t" || arg == "--threads") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -182,7 +186,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency(); params.n_threads = std::thread::hardware_concurrency();
} }
} else if (arg == "-tb" || arg == "--threads-batch") { }
if (arg == "-tb" || arg == "--threads-batch") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -191,7 +197,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_batch <= 0) { if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency(); params.n_threads_batch = std::thread::hardware_concurrency();
} }
} else if (arg == "-td" || arg == "--threads-draft") { }
if (arg == "-td" || arg == "--threads-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -200,7 +208,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_draft <= 0) { if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency(); params.n_threads_draft = std::thread::hardware_concurrency();
} }
} else if (arg == "-tbd" || arg == "--threads-batch-draft") { }
if (arg == "-tbd" || arg == "--threads-batch-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -209,25 +219,37 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_batch_draft <= 0) { if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency(); params.n_threads_batch_draft = std::thread::hardware_concurrency();
} }
} else if (arg == "-p" || arg == "--prompt") { }
if (arg == "-p" || arg == "--prompt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.prompt = argv[i]; params.prompt = argv[i];
} else if (arg == "-e" || arg == "--escape") { }
if (arg == "-e" || arg == "--escape") {
arg_found = true;
params.escape = true; params.escape = true;
} else if (arg == "--prompt-cache") { }
if (arg == "--prompt-cache") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.path_prompt_cache = argv[i]; params.path_prompt_cache = argv[i];
} else if (arg == "--prompt-cache-all") { }
if (arg == "--prompt-cache-all") {
arg_found = true;
params.prompt_cache_all = true; params.prompt_cache_all = true;
} else if (arg == "--prompt-cache-ro") { }
if (arg == "--prompt-cache-ro") {
arg_found = true;
params.prompt_cache_ro = true; params.prompt_cache_ro = true;
} else if (arg == "-bf" || arg == "--binary-file") { }
if (arg == "-bf" || arg == "--binary-file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -244,7 +266,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
ss << file.rdbuf(); ss << file.rdbuf();
params.prompt = ss.str(); params.prompt = ss.str();
fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]); fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]);
} else if (arg == "-f" || arg == "--file") { }
if (arg == "-f" || arg == "--file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -261,51 +285,67 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (!params.prompt.empty() && params.prompt.back() == '\n') { if (!params.prompt.empty() && params.prompt.back() == '\n') {
params.prompt.pop_back(); params.prompt.pop_back();
} }
} else if (arg == "-n" || arg == "--n-predict") { }
if (arg == "-n" || arg == "--n-predict") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} else if (arg == "--top-k") { }
if (arg == "--top-k") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.top_k = std::stoi(argv[i]); sparams.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx-size") { }
if (arg == "-c" || arg == "--ctx-size") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--grp-attn-n" || arg == "-gan") { }
if (arg == "--grp-attn-n" || arg == "-gan") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.grp_attn_n = std::stoi(argv[i]); params.grp_attn_n = std::stoi(argv[i]);
} else if (arg == "--grp-attn-w" || arg == "-gaw") { }
if (arg == "--grp-attn-w" || arg == "-gaw") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.grp_attn_w = std::stoi(argv[i]); params.grp_attn_w = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") { }
if (arg == "--rope-freq-base") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_base = std::stof(argv[i]); params.rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") { }
if (arg == "--rope-freq-scale") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scaling") { }
if (arg == "--rope-scaling") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -315,43 +355,57 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} else if (arg == "--rope-scale") { }
if (arg == "--rope-scale") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_scale = 1.0f/std::stof(argv[i]); params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--yarn-orig-ctx") { }
if (arg == "--yarn-orig-ctx") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_orig_ctx = std::stoi(argv[i]); params.yarn_orig_ctx = std::stoi(argv[i]);
} else if (arg == "--yarn-ext-factor") { }
if (arg == "--yarn-ext-factor") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_ext_factor = std::stof(argv[i]); params.yarn_ext_factor = std::stof(argv[i]);
} else if (arg == "--yarn-attn-factor") { }
if (arg == "--yarn-attn-factor") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_attn_factor = std::stof(argv[i]); params.yarn_attn_factor = std::stof(argv[i]);
} else if (arg == "--yarn-beta-fast") { }
if (arg == "--yarn-beta-fast") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_beta_fast = std::stof(argv[i]); params.yarn_beta_fast = std::stof(argv[i]);
} else if (arg == "--yarn-beta-slow") { }
if (arg == "--yarn-beta-slow") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_beta_slow = std::stof(argv[i]); params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--pooling") { }
if (arg == "--pooling") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -361,118 +415,156 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} else if (arg == "--defrag-thold" || arg == "-dt") { }
if (arg == "--defrag-thold" || arg == "-dt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.defrag_thold = std::stof(argv[i]); params.defrag_thold = std::stof(argv[i]);
} else if (arg == "--samplers") { }
if (arg == "--samplers") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
const auto sampler_names = string_split(argv[i], ';'); const auto sampler_names = string_split(argv[i], ';');
sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); sparams.samplers_sequence = sampler_types_from_names(sampler_names, true);
} else if (arg == "--sampling-seq") { }
if (arg == "--sampling-seq") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.samplers_sequence = sampler_types_from_chars(argv[i]); sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
} else if (arg == "--top-p") { }
if (arg == "--top-p") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.top_p = std::stof(argv[i]); sparams.top_p = std::stof(argv[i]);
} else if (arg == "--min-p") { }
if (arg == "--min-p") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.min_p = std::stof(argv[i]); sparams.min_p = std::stof(argv[i]);
} else if (arg == "--temp") { }
if (arg == "--temp") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.temp = std::stof(argv[i]); sparams.temp = std::stof(argv[i]);
sparams.temp = std::max(sparams.temp, 0.0f); sparams.temp = std::max(sparams.temp, 0.0f);
} else if (arg == "--tfs") { }
if (arg == "--tfs") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.tfs_z = std::stof(argv[i]); sparams.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") { }
if (arg == "--typical") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.typical_p = std::stof(argv[i]); sparams.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat-last-n") { }
if (arg == "--repeat-last-n") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_last_n = std::stoi(argv[i]); sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") { }
if (arg == "--repeat-penalty") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_repeat = std::stof(argv[i]); sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { }
if (arg == "--frequency-penalty") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_freq = std::stof(argv[i]); sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { }
if (arg == "--presence-penalty") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_present = std::stof(argv[i]); sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--dynatemp-range") { }
if (arg == "--dynatemp-range") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.dynatemp_range = std::stof(argv[i]); sparams.dynatemp_range = std::stof(argv[i]);
} else if (arg == "--dynatemp-exp") { }
if (arg == "--dynatemp-exp") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.dynatemp_exponent = std::stof(argv[i]); sparams.dynatemp_exponent = std::stof(argv[i]);
} else if (arg == "--mirostat") { }
if (arg == "--mirostat") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.mirostat = std::stoi(argv[i]); sparams.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat-lr") { }
if (arg == "--mirostat-lr") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.mirostat_eta = std::stof(argv[i]); sparams.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat-ent") { }
if (arg == "--mirostat-ent") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.mirostat_tau = std::stof(argv[i]); sparams.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") { }
if (arg == "--cfg-negative-prompt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.cfg_negative_prompt = argv[i]; sparams.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-negative-prompt-file") { }
if (arg == "--cfg-negative-prompt-file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -487,92 +579,121 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
sparams.cfg_negative_prompt.pop_back(); sparams.cfg_negative_prompt.pop_back();
} }
} else if (arg == "--cfg-scale") { }
if (arg == "--cfg-scale") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.cfg_scale = std::stof(argv[i]); sparams.cfg_scale = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") { }
if (arg == "-b" || arg == "--batch-size") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_batch = std::stoi(argv[i]); params.n_batch = std::stoi(argv[i]);
} else if (arg == "-ub" || arg == "--ubatch-size") { }
if (arg == "-ub" || arg == "--ubatch-size") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_ubatch = std::stoi(argv[i]); params.n_ubatch = std::stoi(argv[i]);
} else if (arg == "--keep") { }
if (arg == "--keep") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_keep = std::stoi(argv[i]); params.n_keep = std::stoi(argv[i]);
} else if (arg == "--draft") { }
if (arg == "--draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_draft = std::stoi(argv[i]); params.n_draft = std::stoi(argv[i]);
} else if (arg == "--chunks") { }
if (arg == "--chunks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_chunks = std::stoi(argv[i]); params.n_chunks = std::stoi(argv[i]);
} else if (arg == "-np" || arg == "--parallel") { }
if (arg == "-np" || arg == "--parallel") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_parallel = std::stoi(argv[i]); params.n_parallel = std::stoi(argv[i]);
} else if (arg == "-ns" || arg == "--sequences") { }
if (arg == "-ns" || arg == "--sequences") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_sequences = std::stoi(argv[i]); params.n_sequences = std::stoi(argv[i]);
} else if (arg == "--p-split" || arg == "-ps") { }
if (arg == "--p-split" || arg == "-ps") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.p_split = std::stof(argv[i]); params.p_split = std::stof(argv[i]);
} else if (arg == "-m" || arg == "--model") { }
if (arg == "-m" || arg == "--model") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model = argv[i]; params.model = argv[i];
} else if (arg == "-mu" || arg == "--model-url") { }
if (arg == "-mu" || arg == "--model-url") {
// Mark the flag as recognized, as every sibling branch in this chain does.
// NOTE(review): arg_found is presumably checked after the chain to reject
// unknown arguments; that check is outside this view - confirm.
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model_url = argv[i]; params.model_url = argv[i];
} else if (arg == "-md" || arg == "--model-draft") { }
if (arg == "-md" || arg == "--model-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model_draft = argv[i]; params.model_draft = argv[i];
} else if (arg == "-a" || arg == "--alias") { }
if (arg == "-a" || arg == "--alias") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model_alias = argv[i]; params.model_alias = argv[i];
} else if (arg == "--lora") { }
if (arg == "--lora") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.emplace_back(argv[i], 1.0f); params.lora_adapter.emplace_back(argv[i], 1.0f);
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-scaled") { }
if (arg == "--lora-scaled") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -584,19 +705,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} }
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-base") { }
if (arg == "--lora-base") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_base = argv[i]; params.lora_base = argv[i];
} else if (arg == "--control-vector") { }
if (arg == "--control-vector") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.control_vectors.push_back({ 1.0f, argv[i], }); params.control_vectors.push_back({ 1.0f, argv[i], });
} else if (arg == "--control-vector-scaled") { }
if (arg == "--control-vector-scaled") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -607,7 +734,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.control_vectors.push_back({ std::stof(argv[i]), fname, }); params.control_vectors.push_back({ std::stof(argv[i]), fname, });
} else if (arg == "--control-vector-layer-range") { }
if (arg == "--control-vector-layer-range") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -618,49 +747,85 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.control_vector_layer_end = std::stoi(argv[i]); params.control_vector_layer_end = std::stoi(argv[i]);
} else if (arg == "--mmproj") { }
if (arg == "--mmproj") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mmproj = argv[i]; params.mmproj = argv[i];
} else if (arg == "--image") { }
if (arg == "--image") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.image = argv[i]; params.image = argv[i];
} else if (arg == "-i" || arg == "--interactive") { }
if (arg == "-i" || arg == "--interactive") {
arg_found = true;
params.interactive = true; params.interactive = true;
} else if (arg == "--embedding") { }
if (arg == "--embedding") {
arg_found = true;
params.embedding = true; params.embedding = true;
} else if (arg == "--interactive-first") { }
if (arg == "--interactive-first") {
arg_found = true;
params.interactive_first = true; params.interactive_first = true;
} else if (arg == "-ins" || arg == "--instruct") { }
if (arg == "-ins" || arg == "--instruct") {
arg_found = true;
params.instruct = true; params.instruct = true;
} else if (arg == "-cml" || arg == "--chatml") { }
if (arg == "-cml" || arg == "--chatml") {
arg_found = true;
params.chatml = true; params.chatml = true;
} else if (arg == "--infill") { }
if (arg == "--infill") {
arg_found = true;
params.infill = true; params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") { }
if (arg == "-dkvc" || arg == "--dump-kv-cache") {
arg_found = true;
params.dump_kv_cache = true; params.dump_kv_cache = true;
} else if (arg == "-nkvo" || arg == "--no-kv-offload") { }
if (arg == "-nkvo" || arg == "--no-kv-offload") {
arg_found = true;
params.no_kv_offload = true; params.no_kv_offload = true;
} else if (arg == "-ctk" || arg == "--cache-type-k") { }
if (arg == "-ctk" || arg == "--cache-type-k") {
arg_found = true;
params.cache_type_k = argv[++i]; params.cache_type_k = argv[++i];
} else if (arg == "-ctv" || arg == "--cache-type-v") { }
if (arg == "-ctv" || arg == "--cache-type-v") {
arg_found = true;
params.cache_type_v = argv[++i]; params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") { }
if (arg == "--multiline-input") {
arg_found = true;
params.multiline_input = true; params.multiline_input = true;
} else if (arg == "--simple-io") { }
if (arg == "--simple-io") {
arg_found = true;
params.simple_io = true; params.simple_io = true;
} else if (arg == "-cb" || arg == "--cont-batching") { }
if (arg == "-cb" || arg == "--cont-batching") {
arg_found = true;
params.cont_batching = true; params.cont_batching = true;
} else if (arg == "--color") { }
if (arg == "--color") {
arg_found = true;
params.use_color = true; params.use_color = true;
} else if (arg == "--mlock") { }
if (arg == "--mlock") {
arg_found = true;
params.use_mlock = true; params.use_mlock = true;
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { }
if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -670,7 +835,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
} }
} else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { }
if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -680,7 +847,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
} }
} else if (arg == "--main-gpu" || arg == "-mg") { }
if (arg == "--main-gpu" || arg == "-mg") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -689,7 +858,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
#ifndef GGML_USE_CUBLAS_SYCL #ifndef GGML_USE_CUBLAS_SYCL
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the main GPU has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the main GPU has no effect.\n");
#endif // GGML_USE_CUBLAS_SYCL #endif // GGML_USE_CUBLAS_SYCL
} else if (arg == "--split-mode" || arg == "-sm") { }
if (arg == "--split-mode" || arg == "-sm") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -713,7 +884,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the split mode has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the split mode has no effect.\n");
#endif // GGML_USE_CUBLAS_SYCL #endif // GGML_USE_CUBLAS_SYCL
} else if (arg == "--tensor-split" || arg == "-ts") { }
if (arg == "--tensor-split" || arg == "-ts") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -738,9 +911,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
#ifndef GGML_USE_CUBLAS_SYCL_VULKAN #ifndef GGML_USE_CUBLAS_SYCL_VULKAN
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan. Setting a tensor split has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan. Setting a tensor split has no effect.\n");
#endif // GGML_USE_CUBLAS_SYCL #endif // GGML_USE_CUBLAS_SYCL
} else if (arg == "--no-mmap") { }
if (arg == "--no-mmap") {
arg_found = true;
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--numa") { }
if (arg == "--numa") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -750,17 +927,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} else if (arg == "--verbose-prompt") { }
if (arg == "--verbose-prompt") {
arg_found = true;
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "--no-display-prompt") { }
if (arg == "--no-display-prompt") {
arg_found = true;
params.display_prompt = false; params.display_prompt = false;
} else if (arg == "-r" || arg == "--reverse-prompt") { }
if (arg == "-r" || arg == "--reverse-prompt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.antiprompt.emplace_back(argv[i]); params.antiprompt.emplace_back(argv[i]);
} else if (arg == "-ld" || arg == "--logdir") { }
if (arg == "-ld" || arg == "--logdir") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -770,63 +955,93 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.logdir.back() != DIRECTORY_SEPARATOR) { if (params.logdir.back() != DIRECTORY_SEPARATOR) {
params.logdir += DIRECTORY_SEPARATOR; params.logdir += DIRECTORY_SEPARATOR;
} }
} else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { }
if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.logits_file = argv[i]; params.logits_file = argv[i];
} else if (arg == "--perplexity" || arg == "--all-logits") { }
if (arg == "--perplexity" || arg == "--all-logits") {
arg_found = true;
params.logits_all = true; params.logits_all = true;
} else if (arg == "--ppl-stride") { }
if (arg == "--ppl-stride") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.ppl_stride = std::stoi(argv[i]); params.ppl_stride = std::stoi(argv[i]);
} else if (arg == "-ptc" || arg == "--print-token-count") { }
if (arg == "-ptc" || arg == "--print-token-count") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_print = std::stoi(argv[i]); params.n_print = std::stoi(argv[i]);
} else if (arg == "--ppl-output-type") { }
if (arg == "--ppl-output-type") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.ppl_output_type = std::stoi(argv[i]); params.ppl_output_type = std::stoi(argv[i]);
} else if (arg == "--hellaswag") { }
if (arg == "--hellaswag") {
arg_found = true;
params.hellaswag = true; params.hellaswag = true;
} else if (arg == "--hellaswag-tasks") { }
if (arg == "--hellaswag-tasks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.hellaswag_tasks = std::stoi(argv[i]); params.hellaswag_tasks = std::stoi(argv[i]);
} else if (arg == "--winogrande") { }
if (arg == "--winogrande") {
arg_found = true;
params.winogrande = true; params.winogrande = true;
} else if (arg == "--winogrande-tasks") { }
if (arg == "--winogrande-tasks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.winogrande_tasks = std::stoi(argv[i]); params.winogrande_tasks = std::stoi(argv[i]);
} else if (arg == "--multiple-choice") { }
if (arg == "--multiple-choice") {
arg_found = true;
params.multiple_choice = true; params.multiple_choice = true;
} else if (arg == "--multiple-choice-tasks") { }
if (arg == "--multiple-choice-tasks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.multiple_choice_tasks = std::stoi(argv[i]); params.multiple_choice_tasks = std::stoi(argv[i]);
} else if (arg == "--kl-divergence") { }
if (arg == "--kl-divergence") {
arg_found = true;
params.kl_divergence = true; params.kl_divergence = true;
} else if (arg == "--ignore-eos") { }
if (arg == "--ignore-eos") {
arg_found = true;
params.ignore_eos = true; params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") { }
if (arg == "--no-penalize-nl") {
arg_found = true;
sparams.penalize_nl = false; sparams.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") { }
if (arg == "-l" || arg == "--logit-bias") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -845,36 +1060,51 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
} else if (arg == "-h" || arg == "--help") { }
if (arg == "-h" || arg == "--help") {
arg_found = true;
return false; return false;
}
} else if (arg == "--version") { if (arg == "--version") {
arg_found = true;
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
exit(0); exit(0);
} else if (arg == "--random-prompt") { }
if (arg == "--random-prompt") {
arg_found = true;
params.random_prompt = true; params.random_prompt = true;
} else if (arg == "--in-prefix-bos") { }
if (arg == "--in-prefix-bos") {
arg_found = true;
params.input_prefix_bos = true; params.input_prefix_bos = true;
} else if (arg == "--in-prefix") { }
if (arg == "--in-prefix") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.input_prefix = argv[i]; params.input_prefix = argv[i];
} else if (arg == "--in-suffix") { }
if (arg == "--in-suffix") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.input_suffix = argv[i]; params.input_suffix = argv[i];
} else if (arg == "--grammar") { }
if (arg == "--grammar") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.grammar = argv[i]; sparams.grammar = argv[i];
} else if (arg == "--grammar-file") { }
if (arg == "--grammar-file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -890,7 +1120,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::istreambuf_iterator<char>(), std::istreambuf_iterator<char>(),
std::back_inserter(sparams.grammar) std::back_inserter(sparams.grammar)
); );
} else if (arg == "--override-kv") { }
if (arg == "--override-kv") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -933,10 +1165,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.push_back(kvo); params.kv_overrides.push_back(kvo);
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters // Parse args for logging parameters
} else if ( log_param_single_parse( argv[i] ) ) { }
if ( log_param_single_parse( argv[i] ) ) {
arg_found = true;
// Do nothing, log_param_single_parse automatically does it's thing // Do nothing, log_param_single_parse automatically does it's thing
// and returns if a match was found and parsed. // and returns if a match was found and parsed.
} else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { }
if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) {
arg_found = true;
// We have a matching known parameter requiring an argument, // We have a matching known parameter requiring an argument,
// now we need to check if there is anything after this argv // now we need to check if there is anything after this argv
// and flag invalid_param or parse it. // and flag invalid_param or parse it.
@ -950,7 +1186,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} }
// End of Parse args for logging parameters // End of Parse args for logging parameters
#endif // LOG_DISABLE_LOGS #endif // LOG_DISABLE_LOGS
} else { }
if (!arg_found) {
throw std::invalid_argument("error: unknown argument: " + arg); throw std::invalid_argument("error: unknown argument: " + arg);
} }
} }

62
examples/gritlm/README.md Normal file
View file

@ -0,0 +1,62 @@
## Generative Representational Instruction Tuning (GRIT) Example
[gritlm] is a model which can generate embeddings as well as "normal" text,
depending on the instructions in the prompt.
* Paper: https://arxiv.org/pdf/2402.09906.pdf
### Retrieval-Augmented Generation (RAG) use case
One use case for `gritlm` is to use it with RAG. Recall how RAG works: we take
the documents that we want to use as context to ground the large language
model (LLM), and we create token embeddings for them. We then store
these token embeddings in a vector database.
When we perform a query (that is, prompt the LLM), we first create token embeddings
for the query and then search the vector database to retrieve the most
similar vectors, and return the corresponding documents so they can be passed to the LLM as
context. Then the query and the context are passed to the LLM, which
has to _again_ create token embeddings for the query. But because gritlm is used,
the first query can be cached, and the second generation of token embeddings
for the query does not have to be performed at all.
### Running the example
Download a Grit model:
```console
$ scripts/hf.sh --repo cohesionet/GritLM-7B_gguf --file gritlm-7b_q4_1.gguf
```
Run the example using the downloaded model:
```console
$ ./gritlm -m gritlm-7b_q4_1.gguf
Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "A purely peer-to-peer version of electronic cash w" is: 0.605
Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "All text-based language problems can be reduced to" is: 0.103
Cosine similarity between "Generative Representational Instruction Tuning" and "A purely peer-to-peer version of electronic cash w" is: 0.112
Cosine similarity between "Generative Representational Instruction Tuning" and "All text-based language problems can be reduced to" is: 0.547
Oh, brave adventurer, who dared to climb
The lofty peak of Mt. Fuji in the night,
When shadows lurk and ghosts do roam,
And darkness reigns, a fearsome sight.
Thou didst set out, with heart aglow,
To conquer this mountain, so high,
And reach the summit, where the stars do glow,
And the moon shines bright, up in the sky.
Through the mist and fog, thou didst press on,
With steadfast courage, and a steadfast will,
Through the darkness, thou didst not be gone,
But didst climb on, with a steadfast skill.
At last, thou didst reach the summit's crest,
And gazed upon the world below,
And saw the beauty of the night's best,
And felt the peace, that only nature knows.
Oh, brave adventurer, who dared to climb
The lofty peak of Mt. Fuji in the night,
Thou art a hero, in the eyes of all,
For thou didst conquer this mountain, so bright.
```
[gritlm]: https://github.com/ContextualAI/gritlm

95
ggml.c
View file

@ -931,6 +931,101 @@ inline static float vaddvq_f32(float32x4_t v) {
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
#endif #endif
#elif defined(__AVX512F__)
#define GGML_SIMD
// F32 AVX512
// Single-precision SIMD primitives mapped onto 512-bit AVX512F intrinsics.
// A __m512 register holds 16 floats, which is the value given to GGML_F32_EPR.
// presumably GGML_F32_STEP is the number of elements consumed per unrolled
// loop iteration and GGML_F32_EPR the elements per register -- confirm against
// the code that uses these macros (not visible in this section).
#define GGML_F32_STEP 64
#define GGML_F32_EPR 16
#define GGML_F32x16 __m512
#define GGML_F32x16_ZERO _mm512_setzero_ps()
#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
// LOAD/STORE use the unaligned (loadu/storeu) intrinsics.
#define GGML_F32x16_LOAD _mm512_loadu_ps
#define GGML_F32x16_STORE _mm512_storeu_ps
// _mm512_fmadd_ps is defined in AVX512F so no guard is required
// Note the argument permutation: the macro's FIRST argument (a) is passed as
// the intrinsic's LAST argument, i.e. FMA(a, b, c) expands to fmadd(b, c, a),
// which per the Intel intrinsics documentation evaluates b*c + a.
#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
#define GGML_F32x16_ADD _mm512_add_ps
#define GGML_F32x16_MUL _mm512_mul_ps
// Reduces the array of vectors x to the scalar res.
// Up to three folding passes are performed: in each pass the upper half of
// the currently active range is added element-wise onto the lower half
// (x[i] += x[offset+i]) and offset is then halved. Finally the lanes of x[0]
// are summed with _mm512_reduce_add_ps and assigned to res.
// x is modified in place. The macro introduces the locals `offset` and `i`
// inside its own do/while scope.
// GGML_F32_ARR is defined outside this section -- presumably
// GGML_F32_STEP/GGML_F32_EPR (= 4 with the values above), in which case the
// third pass runs zero iterations; TODO confirm.
#define GGML_F32x16_REDUCE(res, x) \
do { \
int offset = GGML_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
} \
res = _mm512_reduce_add_ps(x[0]); \
} while (0)
// TODO: is this optimal ?
// Generic GGML_F32_VEC_* names bound to the 16-lane AVX512 implementations above.
#define GGML_F32_VEC GGML_F32x16
#define GGML_F32_VEC_ZERO GGML_F32x16_ZERO
#define GGML_F32_VEC_SET1 GGML_F32x16_SET1
#define GGML_F32_VEC_LOAD GGML_F32x16_LOAD
#define GGML_F32_VEC_STORE GGML_F32x16_STORE
#define GGML_F32_VEC_FMA GGML_F32x16_FMA
#define GGML_F32_VEC_ADD GGML_F32x16_ADD
#define GGML_F32_VEC_MUL GGML_F32x16_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
// F16 AVX512
// Half-precision data is processed in single precision: values are widened
// to a __m512 of floats on load and narrowed back to 16-bit on store.
#define GGML_F16_STEP 64
#define GGML_F16_EPR 16
// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
#define GGML_F32Cx16 __m512
#define GGML_F32Cx16_ZERO _mm512_setzero_ps()
#define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x)
// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
// so F16C guard isn't required
// LOAD: unaligned 256-bit load from x (cast to __m256i *), then converted to
// a __m512 of floats with _mm512_cvtph_ps.
// STORE: converts y with _mm512_cvtps_ph using rounding-control immediate 0
// and writes the resulting 256 bits to x with an unaligned store.
// NOTE(review): immediate 0 should select round-to-nearest per the Intel
// intrinsics documentation -- confirm.
#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
// Same argument permutation as the F32 variant: FMA(a, b, c) expands to
// fmadd(b, c, a), so the macro's first argument is the addend.
#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
#define GGML_F32Cx16_ADD _mm512_add_ps
#define GGML_F32Cx16_MUL _mm512_mul_ps
// Reduces the array of vectors x to the scalar res: up to three folding
// passes add the upper half of the active range onto the lower half
// (x[i] += x[offset+i], offset halved after each pass), then the lanes of
// x[0] are summed with _mm512_reduce_add_ps. x is modified in place.
// NOTE(review): the initial offset is derived from GGML_F32_ARR, not from an
// F16-specific register count. GGML_F32_ARR is defined outside this section;
// this is only equivalent if the F16 and F32 register counts match (STEP and
// EPR are 64 and 16 for both in this branch) -- confirm.
#define GGML_F32Cx16_REDUCE(res, x) \
do { \
int offset = GGML_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm512_add_ps(x[i], x[offset+i]); \
} \
res = _mm512_reduce_add_ps(x[0]); \
} while (0)
// Generic GGML_F16_VEC_* names bound to the F32-backed implementations above.
// GGML_F16_VEC_LOAD ignores its index argument i; GGML_F16_VEC_STORE takes a
// register array r and stores element r[i].
#define GGML_F16_VEC GGML_F32Cx16
#define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO
#define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p)
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
#define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
#elif defined(__AVX__) #elif defined(__AVX__)
#define GGML_SIMD #define GGML_SIMD