add parameters for embeddings

--embd-normalize
--embd-output-format
--embd-separator
description in the README.md
This commit is contained in:
Yann Follet 2024-05-22 09:29:15 +00:00
parent 95fb0aefab
commit 625bdb5225
4 changed files with 265 additions and 419 deletions

View file

@ -281,24 +281,20 @@ bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> &
return true; return true;
} }
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
llama_sampling_params & sparams = params.sparams; llama_sampling_params & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
// This is temporary, in the future the samplign state will be moved fully to llama_sampling_context. // This is temporary, in the future the samplign state will be moved fully to llama_sampling_context.
params.seed = std::stoul(argv[i]); params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]);
return true; return true;
} }
if (arg == "-t" || arg == "--threads") { if (arg == "-t" || arg == "--threads") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_threads = std::stoi(argv[i]); params.n_threads = std::stoi(argv[i]);
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency(); params.n_threads = std::thread::hardware_concurrency();
@ -306,10 +302,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-tb" || arg == "--threads-batch") { if (arg == "-tb" || arg == "--threads-batch") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_threads_batch = std::stoi(argv[i]); params.n_threads_batch = std::stoi(argv[i]);
if (params.n_threads_batch <= 0) { if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency(); params.n_threads_batch = std::thread::hardware_concurrency();
@ -317,10 +310,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-td" || arg == "--threads-draft") { if (arg == "-td" || arg == "--threads-draft") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_threads_draft = std::stoi(argv[i]); params.n_threads_draft = std::stoi(argv[i]);
if (params.n_threads_draft <= 0) { if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency(); params.n_threads_draft = std::thread::hardware_concurrency();
@ -328,10 +318,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-tbd" || arg == "--threads-batch-draft") { if (arg == "-tbd" || arg == "--threads-batch-draft") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_threads_batch_draft = std::stoi(argv[i]); params.n_threads_batch_draft = std::stoi(argv[i]);
if (params.n_threads_batch_draft <= 0) { if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency(); params.n_threads_batch_draft = std::thread::hardware_concurrency();
@ -339,10 +326,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-p" || arg == "--prompt") { if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.prompt = argv[i]; params.prompt = argv[i];
return true; return true;
} }
@ -351,10 +335,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--prompt-cache") { if (arg == "--prompt-cache") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.path_prompt_cache = argv[i]; params.path_prompt_cache = argv[i];
return true; return true;
} }
@ -367,10 +348,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-bf" || arg == "--binary-file") { if (arg == "-bf" || arg == "--binary-file") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::ifstream file(argv[i], std::ios::binary); std::ifstream file(argv[i], std::ios::binary);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -386,10 +364,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-f" || arg == "--file") { if (arg == "-f" || arg == "--file") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -405,66 +380,42 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-n" || arg == "--n-predict") { if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--top-k") { if (arg == "--top-k") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.top_k = std::stoi(argv[i]); sparams.top_k = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "-c" || arg == "--ctx-size") { if (arg == "-c" || arg == "--ctx-size") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--grp-attn-n" || arg == "-gan") { if (arg == "--grp-attn-n" || arg == "-gan") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.grp_attn_n = std::stoi(argv[i]); params.grp_attn_n = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--grp-attn-w" || arg == "-gaw") { if (arg == "--grp-attn-w" || arg == "-gaw") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.grp_attn_w = std::stoi(argv[i]); params.grp_attn_w = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--rope-freq-base") { if (arg == "--rope-freq-base") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.rope_freq_base = std::stof(argv[i]); params.rope_freq_base = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--rope-freq-scale") { if (arg == "--rope-freq-scale") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--rope-scaling") { if (arg == "--rope-scaling") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::string value(argv[i]); std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
@ -473,58 +424,37 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--rope-scale") { if (arg == "--rope-scale") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.rope_freq_scale = 1.0f / std::stof(argv[i]); params.rope_freq_scale = 1.0f / std::stof(argv[i]);
return true; return true;
} }
if (arg == "--yarn-orig-ctx") { if (arg == "--yarn-orig-ctx") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.yarn_orig_ctx = std::stoi(argv[i]); params.yarn_orig_ctx = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--yarn-ext-factor") { if (arg == "--yarn-ext-factor") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.yarn_ext_factor = std::stof(argv[i]); params.yarn_ext_factor = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--yarn-attn-factor") { if (arg == "--yarn-attn-factor") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.yarn_attn_factor = std::stof(argv[i]); params.yarn_attn_factor = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--yarn-beta-fast") { if (arg == "--yarn-beta-fast") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.yarn_beta_fast = std::stof(argv[i]); params.yarn_beta_fast = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--yarn-beta-slow") { if (arg == "--yarn-beta-slow") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.yarn_beta_slow = std::stof(argv[i]); params.yarn_beta_slow = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--pooling") { if (arg == "--pooling") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::string value(argv[i]); std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
@ -533,157 +463,100 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--defrag-thold" || arg == "-dt") { if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.defrag_thold = std::stof(argv[i]); params.defrag_thold = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--samplers") { if (arg == "--samplers") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
const auto sampler_names = string_split(argv[i], ';'); const auto sampler_names = string_split(argv[i], ';');
sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); sparams.samplers_sequence = sampler_types_from_names(sampler_names, true);
return true; return true;
} }
if (arg == "--sampling-seq") { if (arg == "--sampling-seq") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.samplers_sequence = sampler_types_from_chars(argv[i]); sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
return true; return true;
} }
if (arg == "--top-p") { if (arg == "--top-p") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.top_p = std::stof(argv[i]); sparams.top_p = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--min-p") { if (arg == "--min-p") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.min_p = std::stof(argv[i]); sparams.min_p = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--temp") { if (arg == "--temp") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.temp = std::stof(argv[i]); sparams.temp = std::stof(argv[i]);
sparams.temp = std::max(sparams.temp, 0.0f); sparams.temp = std::max(sparams.temp, 0.0f);
return true; return true;
} }
if (arg == "--tfs") { if (arg == "--tfs") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.tfs_z = std::stof(argv[i]); sparams.tfs_z = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--typical") { if (arg == "--typical") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.typical_p = std::stof(argv[i]); sparams.typical_p = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--repeat-last-n") { if (arg == "--repeat-last-n") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.penalty_last_n = std::stoi(argv[i]); sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
return true; return true;
} }
if (arg == "--repeat-penalty") { if (arg == "--repeat-penalty") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.penalty_repeat = std::stof(argv[i]); sparams.penalty_repeat = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--frequency-penalty") { if (arg == "--frequency-penalty") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.penalty_freq = std::stof(argv[i]); sparams.penalty_freq = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--presence-penalty") { if (arg == "--presence-penalty") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.penalty_present = std::stof(argv[i]); sparams.penalty_present = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--dynatemp-range") { if (arg == "--dynatemp-range") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.dynatemp_range = std::stof(argv[i]); sparams.dynatemp_range = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--dynatemp-exp") { if (arg == "--dynatemp-exp") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.dynatemp_exponent = std::stof(argv[i]); sparams.dynatemp_exponent = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--mirostat") { if (arg == "--mirostat") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.mirostat = std::stoi(argv[i]); sparams.mirostat = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--mirostat-lr") { if (arg == "--mirostat-lr") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.mirostat_eta = std::stof(argv[i]); sparams.mirostat_eta = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--mirostat-ent") { if (arg == "--mirostat-ent") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.mirostat_tau = std::stof(argv[i]); sparams.mirostat_tau = std::stof(argv[i]);
return true; return true;
} }
if (arg == "--cfg-negative-prompt") { if (arg == "--cfg-negative-prompt") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.cfg_negative_prompt = argv[i]; sparams.cfg_negative_prompt = argv[i];
return true; return true;
} }
if (arg == "--cfg-negative-prompt-file") { if (arg == "--cfg-negative-prompt-file") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -697,203 +570,125 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--cfg-scale") { if (arg == "--cfg-scale") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.cfg_scale = std::stof(argv[i]); sparams.cfg_scale = std::stof(argv[i]);
return true; return true;
} }
if (arg == "-b" || arg == "--batch-size") { if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_batch = std::stoi(argv[i]); params.n_batch = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "-ub" || arg == "--ubatch-size") { if (arg == "-ub" || arg == "--ubatch-size") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_ubatch = std::stoi(argv[i]); params.n_ubatch = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--keep") { if (arg == "--keep") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_keep = std::stoi(argv[i]); params.n_keep = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--draft") { if (arg == "--draft") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_draft = std::stoi(argv[i]); params.n_draft = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--chunks") { if (arg == "--chunks") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_chunks = std::stoi(argv[i]); params.n_chunks = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "-np" || arg == "--parallel") { if (arg == "-np" || arg == "--parallel") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_parallel = std::stoi(argv[i]); params.n_parallel = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "-ns" || arg == "--sequences") { if (arg == "-ns" || arg == "--sequences") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_sequences = std::stoi(argv[i]); params.n_sequences = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--p-split" || arg == "-ps") { if (arg == "--p-split" || arg == "-ps") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.p_split = std::stof(argv[i]); params.p_split = std::stof(argv[i]);
return true; return true;
} }
if (arg == "-m" || arg == "--model") { if (arg == "-m" || arg == "--model") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.model = argv[i]; params.model = argv[i];
return true; return true;
} }
if (arg == "-md" || arg == "--model-draft") { if (arg == "-md" || arg == "--model-draft") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.model_draft = argv[i]; params.model_draft = argv[i];
return true; return true;
} }
if (arg == "-a" || arg == "--alias") { if (arg == "-a" || arg == "--alias") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.model_alias = argv[i]; params.model_alias = argv[i];
return true; return true;
} }
if (arg == "-mu" || arg == "--model-url") { if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.model_url = argv[i]; params.model_url = argv[i];
return true; return true;
} }
if (arg == "-hfr" || arg == "--hf-repo") { if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.hf_repo = argv[i]; params.hf_repo = argv[i];
return true; return true;
} }
if (arg == "-hff" || arg == "--hf-file") { if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.hf_file = argv[i]; params.hf_file = argv[i];
return true; return true;
} }
if (arg == "--lora") { if (arg == "--lora") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.lora_adapter.emplace_back(argv[i], 1.0f); params.lora_adapter.emplace_back(argv[i], 1.0f);
params.use_mmap = false; params.use_mmap = false;
return true; return true;
} }
if (arg == "--lora-scaled") { if (arg == "--lora-scaled") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
const char* lora_adapter = argv[i]; const char* lora_adapter = argv[i];
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false; params.use_mmap = false;
return true; return true;
} }
if (arg == "--lora-base") { if (arg == "--lora-base") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.lora_base = argv[i]; params.lora_base = argv[i];
return true; return true;
} }
if (arg == "--control-vector") { if (arg == "--control-vector") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.control_vectors.push_back({ 1.0f, argv[i], }); params.control_vectors.push_back({ 1.0f, argv[i], });
return true; return true;
} }
if (arg == "--control-vector-scaled") { if (arg == "--control-vector-scaled") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
const char* fname = argv[i]; const char* fname = argv[i];
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.control_vectors.push_back({ std::stof(argv[i]), fname, }); params.control_vectors.push_back({ std::stof(argv[i]), fname, });
return true; return true;
} }
if (arg == "--control-vector-layer-range") { if (arg == "--control-vector-layer-range") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.control_vector_layer_start = std::stoi(argv[i]); params.control_vector_layer_start = std::stoi(argv[i]);
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.control_vector_layer_end = std::stoi(argv[i]); params.control_vector_layer_end = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "--mmproj") { if (arg == "--mmproj") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.mmproj = argv[i]; params.mmproj = argv[i];
return true; return true;
} }
if (arg == "--image") { if (arg == "--image") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.image.emplace_back(argv[i]); params.image.emplace_back(argv[i]);
return true; return true;
} }
@ -909,6 +704,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.embedding = true; params.embedding = true;
return true; return true;
} }
if (arg == "--embd-normalize") {
CHECK_ARG
params.embd_normalize = std::stoi(argv[i]);
return true;
}
if (arg == "--embd-output-format") {
CHECK_ARG
params.embd_out = argv[i];
return true;
}
if (arg == "--embd-separator") {
CHECK_ARG
params.embd_sep = argv[i];
return true;
}
if (arg == "--interactive-first") { if (arg == "--interactive-first") {
params.interactive_first = true; params.interactive_first = true;
return true; return true;
@ -970,10 +780,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_gpu_layers = std::stoi(argv[i]); params.n_gpu_layers = std::stoi(argv[i]);
if (!llama_supports_gpu_offload()) { if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
@ -982,10 +789,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_gpu_layers_draft = std::stoi(argv[i]); params.n_gpu_layers_draft = std::stoi(argv[i]);
if (!llama_supports_gpu_offload()) { if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
@ -994,10 +798,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--main-gpu" || arg == "-mg") { if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.main_gpu = std::stoi(argv[i]); params.main_gpu = std::stoi(argv[i]);
#ifndef GGML_USE_CUDA_SYCL #ifndef GGML_USE_CUDA_SYCL
fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the main GPU has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the main GPU has no effect.\n");
@ -1005,10 +806,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--split-mode" || arg == "-sm") { if (arg == "--split-mode" || arg == "-sm") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::string arg_next = argv[i]; std::string arg_next = argv[i];
if (arg_next == "none") { if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_MODE_NONE; params.split_mode = LLAMA_SPLIT_MODE_NONE;
@ -1033,10 +831,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--tensor-split" || arg == "-ts") { if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::string arg_next = argv[i]; std::string arg_next = argv[i];
// split string by , and / // split string by , and /
@ -1061,10 +856,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--rpc") { if (arg == "--rpc") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.rpc_servers = argv[i]; params.rpc_servers = argv[i];
return true; return true;
} }
@ -1073,10 +865,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--numa") { if (arg == "--numa") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::string value(argv[i]); std::string value(argv[i]);
/**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
@ -1093,18 +882,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-r" || arg == "--reverse-prompt") { if (arg == "-r" || arg == "--reverse-prompt") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.antiprompt.emplace_back(argv[i]); params.antiprompt.emplace_back(argv[i]);
return true; return true;
} }
if (arg == "-ld" || arg == "--logdir") { if (arg == "-ld" || arg == "--logdir") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.logdir = argv[i]; params.logdir = argv[i];
if (params.logdir.back() != DIRECTORY_SEPARATOR) { if (params.logdir.back() != DIRECTORY_SEPARATOR) {
@ -1113,26 +896,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-lcs" || arg == "--lookup-cache-static") { if (arg == "-lcs" || arg == "--lookup-cache-static") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.lookup_cache_static = argv[i]; params.lookup_cache_static = argv[i];
return true; return true;
} }
if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.lookup_cache_dynamic = argv[i]; params.lookup_cache_dynamic = argv[i];
return true; return true;
} }
if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.logits_file = argv[i]; params.logits_file = argv[i];
return true; return true;
} }
@ -1141,18 +915,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--ppl-stride") { if (arg == "--ppl-stride") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.ppl_stride = std::stoi(argv[i]); params.ppl_stride = std::stoi(argv[i]);
return true; return true;
} }
if (arg == "-ptc" || arg == "--print-token-count") { if (arg == "-ptc" || arg == "--print-token-count") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.n_print = std::stoi(argv[i]); params.n_print = std::stoi(argv[i]);
return true; return true;
} }
@ -1161,10 +929,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--ppl-output-type") { if (arg == "--ppl-output-type") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.ppl_output_type = std::stoi(argv[i]); params.ppl_output_type = std::stoi(argv[i]);
return true; return true;
} }
@ -1173,10 +938,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--hellaswag-tasks") { if (arg == "--hellaswag-tasks") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.hellaswag_tasks = std::stoi(argv[i]); params.hellaswag_tasks = std::stoi(argv[i]);
return true; return true;
} }
@ -1185,10 +947,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--winogrande-tasks") { if (arg == "--winogrande-tasks") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.winogrande_tasks = std::stoi(argv[i]); params.winogrande_tasks = std::stoi(argv[i]);
return true; return true;
} }
@ -1197,10 +956,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--multiple-choice-tasks") { if (arg == "--multiple-choice-tasks") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.multiple_choice_tasks = std::stoi(argv[i]); params.multiple_choice_tasks = std::stoi(argv[i]);
return true; return true;
} }
@ -1217,10 +973,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-l" || arg == "--logit-bias") { if (arg == "-l" || arg == "--logit-bias") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::stringstream ss(argv[i]); std::stringstream ss(argv[i]);
llama_token key; llama_token key;
char sign; char sign;
@ -1257,34 +1010,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "--in-prefix") { if (arg == "--in-prefix") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.input_prefix = argv[i]; params.input_prefix = argv[i];
return true; return true;
} }
if (arg == "--in-suffix") { if (arg == "--in-suffix") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
params.input_suffix = argv[i]; params.input_suffix = argv[i];
return true; return true;
} }
if (arg == "--grammar") { if (arg == "--grammar") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.grammar = argv[i]; sparams.grammar = argv[i];
return true; return true;
} }
if (arg == "--grammar-file") { if (arg == "--grammar-file") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -1299,18 +1040,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true; return true;
} }
if (arg == "-j" || arg == "--json-schema") { if (arg == "-j" || arg == "--json-schema") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true; return true;
} }
if (arg == "--override-kv") { if (arg == "--override-kv") {
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
if (!parse_kv_override(argv[i], params.kv_overrides)) { if (!parse_kv_override(argv[i], params.kv_overrides)) {
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
invalid_param = true; invalid_param = true;
@ -1329,10 +1064,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
// We have a matching known parameter requiring an argument, // We have a matching known parameter requiring an argument,
// now we need to check if there is anything after this argv // now we need to check if there is anything after this argv
// and flag invalid_param or parse it. // and flag invalid_param or parse it.
if (++i >= argc) { CHECK_ARG
invalid_param = true;
return true;
}
if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) { if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) {
invalid_param = true; invalid_param = true;
return true; return true;
@ -2855,14 +2587,34 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("\n=== Done dumping\n"); printf("\n=== Done dumping\n");
} }
void llama_embd_normalize(const float * inp, float * out, int n) { void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
double sum = 0.0; double sum = 0.0;
for (int i = 0; i < n; i++) {
sum += inp[i] * inp[i]; switch (embd_norm) {
case -1: // no normalisation
sum = 1.0;
break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
}
sum /= 32760.0; // make an int16 range
break;
case 2: // euclidean
for (int i = 0; i < n; i++) {
sum += inp[i] * inp[i];
}
sum = std::sqrt(sum);
break;
default: // p-norm (euclidean is p-norm p=2)
for (int i = 0; i < n; i++) {
sum += std::pow(std::abs(inp[i]), embd_norm);
}
sum = std::pow(sum, 1.0 / embd_norm);
break;
} }
sum = sqrt(sum);
const float norm = sum > 0.0 ? 1.0f / sum : 0.0f; const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
out[i] = inp[i] * norm; out[i] = inp[i] * norm;
@ -2880,6 +2632,14 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
sum2 += embd2[i] * embd2[i]; sum2 += embd2[i] * embd2[i];
} }
// Handle the case where one or both vectors are zero vectors
if (sum1 == 0.0 || sum2 == 0.0) {
if (sum1 == 0.0 && sum2 == 0.0) {
return 1.0f; // two zero vectors are similar
}
return 0.0f;
}
return sum / (sqrt(sum1) * sqrt(sum2)); return sum / (sqrt(sum1) * sqrt(sum2));
} }

View file

@ -148,6 +148,9 @@ struct gpt_params {
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
bool embedding = false; // get only sentence embedding bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [] or [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embendings
bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
bool interactive_first = false; // wait for user input immediately bool interactive_first = false; // wait for user input immediately
bool multiline_input = false; // reverse the usage of `\` bool multiline_input = false; // reverse the usage of `\`
@ -305,7 +308,7 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
// Embedding utils // Embedding utils
// //
void llama_embd_normalize(const float * inp, float * out, int n); void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);

View file

@ -19,3 +19,43 @@ embedding.exe -m ./path/to/model --log-disable -p "Hello World!" 2>$null
``` ```
The above command will output space-separated float values. The above command will output space-separated float values.
## extra parameters
### --embd-normalize $integer$
| $integer$ | description | formula |
|-----------|---------------------|---------|
| $-1$ | none |
| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$
| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$
| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$
| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$
### --embd-output-format $'string'$
| $'string'$ | description | |
|------------|------------------------------|--|
| '' | same as before | (default)
| 'array' | single embeddings | $[[x_1,...,x_n]]$
| | multiple embeddings | $[_0[x_1,...,x_n],_1[x_1,...,x_n],...,_{n-1}[x_1,...,x_n]]$
| 'json' | openai style |
| 'json+' | add cosine similarity matrix |
### --embd-separator $"string"$
| $"string"$ | |
|--------------|-|
| "\n" | (default)
| "<#embSep#>" | for exemple
| "<#sep#>" | other exemple
## examples
### Unix-based systems (Linux, macOS, etc.):
```bash
./embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
```
### Windows:
```powershell
embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
```

View file

@ -7,13 +7,19 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
static std::vector<std::string> split_lines(const std::string & s) { static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
std::string line;
std::vector<std::string> lines; std::vector<std::string> lines;
std::stringstream ss(s); size_t start = 0;
while (std::getline(ss, line)) { size_t end = s.find(separator);
lines.push_back(line);
while (end != std::string::npos) {
lines.push_back(s.substr(start, end - start));
start = end + separator.length();
end = s.find(separator, start);
} }
lines.push_back(s.substr(start)); // Add the last part
return lines; return lines;
} }
@ -23,7 +29,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
} }
} }
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
@ -49,13 +55,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
} }
float * out = output + batch.seq_id[i][0] * n_embd; float * out = output + batch.seq_id[i][0] * n_embd;
//TODO: I would also add a parameter here to enable normalization or not. llama_embd_normalize(embd, out, n_embd, embd_norm);
/*fprintf(stdout, "unnormalized_embedding:");
for (int hh = 0; hh < n_embd; hh++) {
fprintf(stdout, "%9.6f ", embd[hh]);
}
fprintf(stdout, "\n");*/
llama_embd_normalize(embd, out, n_embd);
} }
} }
@ -111,7 +111,7 @@ int main(int argc, char ** argv) {
} }
// split the prompt into lines // split the prompt into lines
std::vector<std::string> prompts = split_lines(params.prompt); std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
// max batch size // max batch size
const uint64_t n_batch = params.n_batch; const uint64_t n_batch = params.n_batch;
@ -171,7 +171,7 @@ int main(int argc, char ** argv) {
// encode if at capacity // encode if at capacity
if (batch.n_tokens + n_toks > n_batch) { if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd; float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd); batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
llama_batch_clear(batch); llama_batch_clear(batch);
p += s; p += s;
s = 0; s = 0;
@ -184,29 +184,72 @@ int main(int argc, char ** argv) {
// final batch // final batch
float * out = emb + p * n_embd; float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd); batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
// print the first part of the embeddings or for a single prompt, the full embedding if (params.embd_out=="") {
fprintf(stdout, "\n"); // print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
fprintf(stdout, "\n"); fprintf(stdout, "\n");
} for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
// print cosine similarity matrix for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
if (n_prompts > 1) { if (params.embd_normalize==0)
fprintf(stdout, "\n"); fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
printf("cosine similarity matrix:\n\n"); else
for (int i = 0; i < n_prompts; i++) { fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
} }
fprintf(stdout, "\n"); fprintf(stdout, "\n");
} }
// print cosine similarity matrix
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "\n");
}
}
}
if (params.embd_out=="json" || params.embd_out=="json+" || params.embd_out=="array") {
const bool notArray = params.embd_out!="array";
fprintf(stdout, notArray?"{\n 'object': 'list',\n 'data': [\n":"[");
for (int j = 0;;) { // at least one iteration (one prompt)
if (notArray) fprintf(stdout, " {\n 'object': 'embedding',\n 'index': %d,\n 'embedding': ",j);
fprintf(stdout, "[");
for (int i = 0;;) { // at least one iteration (n_embd > 0)
fprintf(stdout, params.embd_normalize==0?"%1.0f":"%1.7f", emb[j * n_embd + i]);
i++;
if (i < n_embd) fprintf(stdout, ","); else break;
}
fprintf(stdout, notArray?"]\n }":"]");
j++;
if (j < n_prompts) fprintf(stdout, notArray?",\n":","); else break;
}
fprintf(stdout, notArray?"\n ]":"]\n");
if (params.embd_out=="json+" && n_prompts > 1) {
fprintf(stdout, ",\n cosineSimilarity: [\n");
for (int i = 0;;) { // at least two iteration (n_prompts > 1)
fprintf(stdout, " [");
for (int j = 0;;) { // at least two iteration (n_prompts > 1)
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f", sim);
j++;
if (j < n_prompts) fprintf(stdout, ", "); else break;
}
fprintf(stdout, " ]");
i++;
if (i < n_prompts) fprintf(stdout, ",\n"); else break;
}
fprintf(stdout, "\n ]");
}
if (notArray) fprintf(stdout, "\n}\n");
} }
// clean up // clean up