merge master

commit 62dc17022b

60 changed files with 4698 additions and 966 deletions
.github/workflows/build.yml (vendored, 8 changes)
@@ -317,7 +317,7 @@ jobs:
           wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
           sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
           sudo apt-get update -y
-          sudo apt-get install -y build-essential vulkan-sdk
+          sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk

       - name: Build
         id: cmake_build
@@ -327,6 +327,12 @@ jobs:
           cmake -DGGML_VULKAN=ON ..
           cmake --build . --config Release -j $(nproc)

+      - name: Test
+        id: cmake_test
+        run: |
+          cd build
+          ctest -L main --verbose --timeout 900
+
   ubuntu-22-cmake-hip:
     runs-on: ubuntu-22.04
     container: rocm/dev-ubuntu-22.04:6.0.2

@@ -221,7 +221,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 | [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU |
 | [MUSA](docs/build.md#musa) | Moore Threads MTT GPU |
 | [CUDA](docs/build.md#cuda) | Nvidia GPU |
-| [hipBLAS](docs/build.md#hipblas) | AMD GPU |
+| [HIP](docs/build.md#hip) | AMD GPU |
 | [Vulkan](docs/build.md#vulkan) | GPU |
 | [CANN](docs/build.md#cann) | Ascend NPU |

@@ -414,7 +414,7 @@ To learn more about model quantization, [read this documentation](examples/quant
 [^1]: [examples/perplexity/README.md](examples/perplexity/README.md)
 [^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)

-## [`llama-bench`](example/bench)
+## [`llama-bench`](examples/llama-bench)

 #### Benchmark the performance of the inference for various parameters.

@@ -119,29 +119,33 @@ std::string common_arg::to_string() {
 // utils
 //

-static void common_params_handle_model_default(common_params & params) {
-    if (!params.hf_repo.empty()) {
+static void common_params_handle_model_default(
+        std::string & model,
+        std::string & model_url,
+        std::string & hf_repo,
+        std::string & hf_file) {
+    if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
+        if (hf_file.empty()) {
+            if (model.empty()) {
                 throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
             }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+            hf_file = model;
+        } else if (model.empty()) {
             // this is to avoid different repo having same file name, or same file name in different subdirs
-            std::string filename = params.hf_repo + "_" + params.hf_file;
+            std::string filename = hf_repo + "_" + hf_file;
             // to make sure we don't have any slashes in the filename
             string_replace_all(filename, "/", "_");
-            params.model = fs_get_cache_file(filename);
+            model = fs_get_cache_file(filename);
         }
-    } else if (!params.model_url.empty()) {
-        if (params.model.empty()) {
-            auto f = string_split<std::string>(params.model_url, '#').front();
+    } else if (!model_url.empty()) {
+        if (model.empty()) {
+            auto f = string_split<std::string>(model_url, '#').front();
             f = string_split<std::string>(f, '?').front();
-            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
+            model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
-    } else if (params.model.empty()) {
-        params.model = DEFAULT_MODEL_PATH;
+    } else if (model.empty()) {
+        model = DEFAULT_MODEL_PATH;
     }
 }

@@ -276,7 +280,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }

-    common_params_handle_model_default(params);
+    // TODO: refactor model params in a common struct
+    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file);
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);

     if (params.escape) {
         string_process_escapes(params.prompt);
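
Note: the refactor above also makes the cache naming rule easy to see in isolation. A standalone sketch (hypothetical helper, not part of the commit) of how a repo/file pair maps to a collision-free cache file name:

    #include <string>

    // mirrors the logic above:
    // "some-org/models" + "sub/model-q4_0.gguf"
    //   -> "some-org_models_sub_model-q4_0.gguf"
    static std::string cache_name(std::string repo, std::string file) {
        std::string filename = repo + "_" + file;
        for (auto & c : filename) {
            if (c == '/') {
                c = '_'; // no slashes allowed in a cache file name
            }
        }
        return filename;
    }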
@@ -842,7 +848,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_sparam());
     add_opt(common_arg(
-        {"--sampling-seq"}, "SEQUENCE",
+        {"--sampling-seq", "--sampler-seq"}, "SEQUENCE",
         string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
         [](common_params & params, const std::string & value) {
            params.sampling.samplers = common_sampler_types_from_chars(value);
@@ -855,13 +861,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.sampling.ignore_eos = true;
        }
    ).set_sparam());
-    add_opt(common_arg(
-        {"--penalize-nl"},
-        string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
-        [](common_params & params) {
-            params.sampling.penalize_nl = true;
-        }
-    ).set_sparam());
    add_opt(common_arg(
        {"--temp"}, "N",
        string_format("temperature (default: %.1f)", (double)params.sampling.temp),
@@ -916,6 +915,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--repeat-last-n"}, "N",
        string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
        [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value));
+            }
            params.sampling.penalty_last_n = value;
            params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
        }
@@ -970,6 +972,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"--dry-penalty-last-n"}, "N",
        string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
        [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value));
+            }
            params.sampling.dry_penalty_last_n = value;
        }
    ).set_sparam());
@@ -1582,6 +1587,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.hf_file = value;
        }
    ).set_env("LLAMA_ARG_HF_FILE"));
+    add_opt(common_arg(
+        {"-hfrv", "--hf-repo-v"}, "REPO",
+        "Hugging Face model repository for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HF_REPO_V"));
+    add_opt(common_arg(
+        {"-hffv", "--hf-file-v"}, "FILE",
+        "Hugging Face model file for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_file = value;
+        }
+    ).set_env("LLAMA_ARG_HF_FILE_V"));
    add_opt(common_arg(
        {"-hft", "--hf-token"}, "TOKEN",
        "Hugging Face access token (default: value from HF_TOKEN environment variable)",
@@ -2179,5 +2198,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        }
    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));

+    add_opt(common_arg(
+        {"-mv", "--model-vocoder"}, "FNAME",
+        "vocoder model for audio generation (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.model = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+
    return ctx_arg;
 }

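Note: with the helper parameterized by reference, the vocoder model now flows through exactly the same default-resolution logic as the main model. A minimal sketch of the intended call pattern, assuming the common_params layout from this commit (repo and file names are placeholders for illustration only):

    common_params params;

    // main model via the existing Hugging Face shorthand
    params.hf_repo = "some-org/some-model-GGUF"; // placeholder
    params.hf_file = "some-model-Q8_0.gguf";     // placeholder

    // vocoder via the new -hfrv / -hffv flags
    params.vocoder.hf_repo = "some-org/some-vocoder-GGUF"; // placeholder
    params.vocoder.hf_file = "some-vocoder-F16.gguf";      // placeholder

    // both sets of fields resolve to cache paths through the same helper
    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file);
    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
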
@@ -940,6 +940,25 @@ struct common_init_result common_init_from_params(common_params & params) {
        params.sampling.ignore_eos = false;
    }

+    if (params.sampling.ignore_eos) {
+        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
+            if (llama_token_is_eog(model, i)) {
+                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                params.sampling.logit_bias.push_back({i, -INFINITY});
+            }
+        }
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
    if (params.warmup) {
        LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
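
Note: ignore_eos is implemented purely as logit biasing here: every end-of-generation token gets a bias of -INFINITY, so its probability becomes exactly zero after softmax. A standalone sketch of how such a bias list is applied to raw logits (illustration only, not code from the commit):

    #include <cstdint>
    #include <vector>

    struct logit_bias_entry {
        int32_t token;
        float   bias;
    };

    // apply biases in place; a -INFINITY bias guarantees the token
    // can never be sampled, since exp(-inf) == 0 under softmax
    static void apply_logit_bias(std::vector<float> & logits, const std::vector<logit_bias_entry> & biases) {
        for (const auto & b : biases) {
            logits[b.token] += b.bias;
        }
    }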
@@ -1076,7 +1095,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 #define CURL_MAX_RETRY 3
 #define CURL_RETRY_DELAY_SECONDS 2

-static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
+static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
     int remaining_attempts = max_attempts;

     while (remaining_attempts > 0) {
@@ -1100,7 +1119,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
 }

 static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
     if (!curl) {
@@ -1173,11 +1191,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
         std::string etag;
         std::string last_modified;
     };
+
     common_load_model_from_url_headers headers;
+
     {
         typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
         auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
+            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;

             static std::regex header_regex("([^:]+): (.*)\r\n");
             static std::regex etag_regex("ETag", std::regex_constants::icase);
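
Note: curl_perform_with_retry, whose signature is reformatted above, wraps curl_easy_perform in a bounded retry loop. A simplified sketch of that pattern (illustrative; the real function also logs and scales the delay between attempts):

    #include <chrono>
    #include <thread>
    #include <curl/curl.h>

    static bool perform_with_retry_sketch(CURL * curl, int max_attempts, int retry_delay_seconds) {
        for (int attempt = 1; attempt <= max_attempts; ++attempt) {
            if (curl_easy_perform(curl) == CURLE_OK) {
                return true; // transfer succeeded
            }
            std::this_thread::sleep_for(std::chrono::seconds(retry_delay_seconds));
        }
        return false; // all attempts exhausted
    }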
@@ -1761,7 +1781,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
             break;
         case 0: // max absolute
             for (int i = 0; i < n; i++) {
-                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
             }
             sum /= 32760.0; // make an int16 range
             break;
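
Note: the embd_norm parameter selects the norm: the hunk above shows the max-absolute case (embd_norm == 0), while the call sites updated elsewhere in this commit pass 2 for the Euclidean norm. A self-contained sketch of the Euclidean variant, under that reading of the parameter:

    #include <cmath>

    // L2-normalize inp into out; corresponds to embd_norm == 2
    static void embd_normalize_l2(const float * inp, float * out, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += inp[i] * inp[i];
        }
        const float norm = std::sqrt((float) sum);
        for (int i = 0; i < n; i++) {
            out[i] = norm > 0.0f ? inp[i] / norm : 0.0f;
        }
    }
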
@@ -80,6 +80,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LLAVA,
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
+    LLAMA_EXAMPLE_TTS,

     LLAMA_EXAMPLE_COUNT,
 };

@@ -95,6 +96,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
 };

 // dimensionality reduction methods, used by cvector-generator
@@ -130,7 +132,6 @@ struct common_params_sampling {
     int32_t mirostat     = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau = 5.00f; // target entropy
     float   mirostat_eta = 0.10f; // learning rate
-    bool    penalize_nl  = false; // consider newlines as a repeatable token
     bool    ignore_eos   = false;
     bool    no_perf      = false; // disable performance metrics
     bool    timing_per_token = false;

@@ -139,6 +140,7 @@ struct common_params_sampling {


     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
@@ -158,6 +160,7 @@ struct common_params_sampling {

 struct common_params_speculative {
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
     int32_t n_ctx =  0; // draft context size
     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min =  5; // minimum number of draft tokens to use for speculative decoding
@@ -171,6 +174,14 @@ struct common_params_speculative {
     std::string model = ""; // draft model for speculative decoding // NOLINT
 };

+struct common_params_vocoder {
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+
+    std::string model     = ""; // model path             // NOLINT
+    std::string model_url = ""; // model url to download  // NOLINT
+};
+
 struct common_params {
     int32_t n_predict = -1;   // new tokens to predict
     int32_t n_ctx     = 4096; // context size
@@ -193,11 +204,13 @@ struct common_params {
     float defrag_thold = 0.1f; // KV cache defragmentation threshold

     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t main_gpu     =  0; // the GPU that is used for scratch and small tensors
     float   tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs

     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
@@ -211,8 +224,9 @@ struct common_params {
     enum llama_pooling_type pooling_type     = LLAMA_POOLING_TYPE_UNSPECIFIED;   // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

-    struct common_params_sampling sampling;
+    struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
+    struct common_params_vocoder     vocoder;

     std::string model = "";       // model path  // NOLINT
     std::string model_alias = ""; // model alias // NOLINT
@@ -593,7 +607,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 // Embedding utils
 //

-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);

 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.size(),
                 params.logit_bias.data()));

-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }

                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
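
Note: with COMMON_SAMPLER_TYPE_PENALTIES, the repetition penalties become an ordinary link in the sampler chain instead of a hard-wired stage, so they can be reordered or omitted like any other sampler. A minimal sketch of building such a chain with the new four-argument llama_sampler_init_penalties (parameter values are placeholders):

    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * chain = llama_sampler_chain_init(sparams);

    // penalties first, matching the new default sampler order
    llama_sampler_chain_add(chain, llama_sampler_init_penalties(
            /* penalty_last_n  */ 64,
            /* penalty_repeat  */ 1.3f,
            /* penalty_freq    */ 0.0f,
            /* penalty_present */ 0.0f));
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));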
@@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
         default : return '?';
     }
 }

@@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
         default : return "";
     }
 }

@@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
     };

     // since samplers names are written multiple ways

@@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
     };

     std::vector<common_sampler_type> samplers;

@@ -221,17 +221,17 @@ class Model:
         self.gguf_writer.add_context_length(n_ctx)
         logger.info(f"gguf: context length = {n_ctx}")

-        n_embd = self.find_hparam(["hidden_size", "n_embd"])
-        self.gguf_writer.add_embedding_length(n_embd)
-        logger.info(f"gguf: embedding length = {n_embd}")
+        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
+            self.gguf_writer.add_embedding_length(n_embd)
+            logger.info(f"gguf: embedding length = {n_embd}")

         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
             logger.info(f"gguf: feed forward length = {n_ff}")

-        n_head = self.find_hparam(["num_attention_heads", "n_head"])
-        self.gguf_writer.add_head_count(n_head)
-        logger.info(f"gguf: head count = {n_head}")
+        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
+            self.gguf_writer.add_head_count(n_head)
+            logger.info(f"gguf: head count = {n_head}")

         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
@@ -296,7 +296,9 @@ class Model:
                     break

             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                # TODO: why do we squeeze here?
+                # data = data_torch.squeeze().numpy()
+                data = data_torch.numpy()

                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
@@ -324,6 +326,8 @@ class Model:
                             gguf.MODEL_TENSOR.TIME_MIX_W2,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+                            gguf.MODEL_TENSOR.POSNET_NORM1,
+                            gguf.MODEL_TENSOR.POSNET_NORM2,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -689,6 +693,9 @@ class Model:
         return res
     # Marker: End get_vocab_base_pre

+    def _set_vocab_none(self) -> None:
+        self.gguf_writer.add_tokenizer_model("none")
+
     def _set_vocab_gpt2(self) -> None:
         tokens, toktypes, tokpre = self.get_vocab_base()
         self.gguf_writer.add_tokenizer_model("gpt2")
@@ -2027,6 +2034,44 @@ class Qwen2VLModel(Model):
             yield name, data


+@Model.register("WavTokenizerDec")
+class WavTokenizerDecModel(Model):
+    model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if \
+                name.endswith("codebook.cluster_size") or \
+                name.endswith("codebook.embed_avg") or \
+                name.endswith("codebook.inited"):
+            logger.debug(f"Skipping {name!r}")
+            return []
+
+        logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_vocab(self):
+        self._set_vocab_none()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vocab_size         (self.hparams["vocab_size"])
+        self.gguf_writer.add_features_length    (self.hparams["n_embd_features"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"])
+        self.gguf_writer.add_group_norm_eps     (self.hparams["group_norm_epsilon"])
+        self.gguf_writer.add_group_norm_groups  (self.hparams["group_norm_groups"])
+
+        self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"])
+        self.gguf_writer.add_posnet_block_count     (self.hparams["posnet"]["n_layer"])
+
+        self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"])
+        self.gguf_writer.add_convnext_block_count     (self.hparams["convnext"]["n_layer"])
+
+        self.gguf_writer.add_causal_attention(False)
+
+
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):
     model_arch = gguf.MODEL_ARCH.QWEN2MOE

@@ -51,6 +51,7 @@ else()
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
+    add_subdirectory(tts)
     add_subdirectory(gen-docs)
     if (NOT GGML_BACKEND_DL)
         # these examples use the backends directly and cannot be built with dynamic loading

@@ -65,6 +65,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

     auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;

     llama_sampler * smpl = llama_sampler_chain_init(sparams);

@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         }

         std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
         result.push_back(emb_norm);

 #ifdef GRIT_DEBUG

@@ -19,6 +19,7 @@ android {
         externalNativeBuild {
             cmake {
+                arguments += "-DLLAMA_BUILD_COMMON=ON"
                 arguments += "-DGGML_LLAMAFILE=OFF"
                 arguments += "-DCMAKE_BUILD_TYPE=Release"
                 cppFlags += listOf()
                 arguments += listOf()

@@ -896,7 +896,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
             mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
             // stride = 1, padding = 1, bias is nullptr
-            block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+            block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);

             // layer norm
             // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]

@@ -944,7 +944,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             // block_2
             {
                 // stride = 2
-                block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);

                 // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                 // layer norm

@@ -1005,7 +1005,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             // mlp_2 ne [24, 24, 2048, 1]
             mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
             // weight ne = [3, 3, 2048, 1]
-            struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+            struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
             peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
             peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
             mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
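
Note: the rename from ggml_conv_depthwise_2d to ggml_conv_2d_dw keeps the argument order used above: kernel first, then input, then per-axis stride, padding and dilation. A small sketch with the same call shape, assuming an already initialized ggml context ctx0 (tensor sizes mirror the comments in the hunks):

    // kernel ne = [3, 3, 2048, 1], input ne = [24, 24, 2048, 1]
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,  3,  3, 2048, 1);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 24, 24, 2048, 1);

    // stride 1, padding 1, dilation 1 on both axes
    struct ggml_tensor * out = ggml_conv_2d_dw(ctx0, kernel, input, 1, 1, 1, 1, 1, 1);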
@@ -177,16 +177,11 @@ Example usage: `--temp 0`

 - `--repeat-penalty N`: Control the repetition of token sequences in the generated text (default: 1.0, 1.0 = disabled).
 - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
-- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.

 The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.

 The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).

-Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
-
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
+Example usage: `--repeat-penalty 1.15 --repeat-last-n 128`

 ### DRY Repetition Penalty

 DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).

@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }

         float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
     }
 }

@@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
-| `--penalize-nl` | penalize newline tokens (default: false) |
 | `--temp N` | temperature (default: 0.8) |
 | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
 | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
@@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to

 `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.

-`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
-
 `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.

 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
@@ -441,19 +438,22 @@ These words will not be included in the completion, so make sure to add them to

 `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true`

+`return_tokens`: Return the raw generated token ids in the `tokens` field. Otherwise `tokens` remains empty. Default: `false`
+
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.

 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`

 **Response format**

-- Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
+- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.

 - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:

 ```json
 {
-  "content": "<the token selected by the model>",
+  "content": "<the token generated by the model>",
+  "tokens": [ generated token ids if requested ],
   "probs": [
     {
       "prob": float,
@@ -471,6 +471,7 @@ These words will not be included in the completion, so make sure to add them to
 Notice that each `probs` is an array of length `n_probs`.

 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
+- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
 - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
 - `model`: The path to the model loaded with `-m`
@@ -655,7 +656,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     "mirostat": 0,
     "mirostat_tau": 5.0,
     "mirostat_eta": 0.10000000149011612,
-    "penalize_nl": false,
     "stop": [],
     "max_tokens": -1,
     "n_keep": 0,
@@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \

 ### POST `/v1/embeddings`: OpenAI-compatible embeddings API

+This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Euclidean norm.
+
 *Options:*

 See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
     }'
 ```

+### POST `/embeddings`: non-OpenAI-compatible embeddings API
+
+This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using the Euclidean norm.
+
+Note that the response format of this endpoint is different from `/v1/embeddings`.
+
+*Options:*
+
+Same as the `/v1/embeddings` endpoint.
+
+*Examples:*
+
+Same as the `/v1/embeddings` endpoint.
+
+**Response format**
+
+```json
+[
+  {
+    "index": 0,
+    "embedding": [
+      [ ... embeddings for token 0 ... ],
+      [ ... embeddings for token 1 ... ],
+      [ ... ],
+      [ ... embeddings for token N-1 ... ]
+    ]
+  },
+  ...
+  {
+    "index": P,
+    "embedding": [
+      [ ... embeddings for token 0 ... ],
+      [ ... embeddings for token 1 ... ],
+      [ ... ],
+      [ ... embeddings for token N-1 ... ]
+    ]
+  }
+]
+```
+
 ### GET `/slots`: Returns the current slots processing state

 > [!WARNING]
@@ -845,7 +887,6 @@ Example:
     "mirostat": 0,
     "mirostat_tau": 5.0,
     "mirostat_eta": 0.10000000149011612,
-    "penalize_nl": false,
     "stop": [],
     "max_tokens": -1,
     "n_keep": 0,

Binary file not shown.
@@ -39,7 +39,6 @@
     temperature: 0.8,      // adapt all following parameters to optimized min-p requirements. If for non-english, set to 0.6 or lower
     repeat_last_n: 0,      // 0 = disable penalty, -1 = context size
     repeat_penalty: 1.0,   // 1.0 = disabled
-    penalize_nl: false,    // true only useful for infinite completion
     dry_multiplier: 0.0,   // 0.0 = disabled, 0.8 works well
     dry_base: 1.75,        // 0.0 = disabled
     dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well

@@ -303,7 +303,6 @@
     temperature: 0.7,
     repeat_last_n: 256,    // 0 = disable penalty, -1 = context size
     repeat_penalty: 1.18,  // 1.0 = disabled
-    penalize_nl: false,
     dry_multiplier: 0.0,   // 0.0 = disabled, 0.8 works well
     dry_base: 1.75,        // 0.0 = disabled
     dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
@@ -1006,7 +1005,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

@@ -79,8 +79,9 @@ enum error_type {
 };

 struct slot_params {
-    bool stream       = true;
-    bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+    bool stream        = true;
+    bool cache_prompt  = true;  // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens = false;

     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -135,7 +136,6 @@ struct slot_params {
             {"mirostat",     sampling.mirostat},
             {"mirostat_tau", sampling.mirostat_tau},
             {"mirostat_eta", sampling.mirostat_eta},
-            {"penalize_nl",  sampling.penalize_nl},
             {"stop",         antiprompt},
             {"max_tokens",   n_predict}, // User configured n_predict
             {"n_keep",       n_keep},
@@ -184,6 +184,7 @@ struct server_task {

     static slot_params params_from_json_cmpl(
             const llama_model * model,
+            const llama_context * ctx,
             const common_params & params_base,
             const json & data) {
         slot_params params;
@@ -199,6 +200,7 @@ struct server_task {

         params.stream        = json_value(data, "stream",        false);
         params.cache_prompt  = json_value(data, "cache_prompt",  true);
+        params.return_tokens = json_value(data, "return_tokens", false);
         params.n_predict     = json_value(data, "n_predict",     json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent      = json_value(data, "n_indent",      defaults.n_indent);
         params.n_keep        = json_value(data, "n_keep",        defaults.n_keep);
@@ -226,7 +228,6 @@ struct server_task {
         params.sampling.mirostat     = json_value(data, "mirostat",     defaults.sampling.mirostat);
         params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
         params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
-        params.sampling.penalize_nl  = json_value(data, "penalize_nl",  defaults.sampling.penalize_nl);
         params.sampling.seed         = json_value(data, "seed",         defaults.sampling.seed);
         params.sampling.n_probs      = json_value(data, "n_probs",      defaults.sampling.n_probs);
         params.sampling.min_keep     = json_value(data, "min_keep",     defaults.sampling.min_keep);
@@ -239,8 +240,27 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);

+        // TODO: add more sanity checks for the input parameters
+
+        if (params.sampling.penalty_last_n < -1) {
+            throw std::runtime_error("Error: repeat_last_n must be >= -1");
+        }
+
+        if (params.sampling.dry_penalty_last_n < -1) {
+            throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+        }
+
+        if (params.sampling.penalty_last_n == -1) {
+            // note: should be the slot's context and not the full context, but it's ok
+            params.sampling.penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (params.sampling.dry_penalty_last_n == -1) {
+            params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+        }
+
         if (params.sampling.dry_base < 1.0f) {
             params.sampling.dry_base = defaults.sampling.dry_base;
         }

         // sequence breakers for DRY
@@ -450,7 +470,10 @@ struct completion_token_output {

 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
+
     std::string content;
+    llama_tokens tokens;

     bool stream;
     result_timings timings;
     std::string prompt;
@@ -492,6 +515,7 @@ struct server_task_result_cmpl_final : server_task_result {
         json res = json {
             {"index",   index},
             {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+            {"tokens",  stream ? llama_tokens {} : tokens},
             {"id_slot", id_slot},
             {"stop",    true},
             {"model",   oaicompat_model},
@@ -521,9 +545,9 @@ struct server_task_result_cmpl_final : server_task_result {
         json choices = json::array({json{
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", json{
+            {"message", json {
                 {"content", content},
-                {"role", "assistant"}
+                {"role",    "assistant"}
             }
         }}});

@@ -587,7 +611,9 @@ struct server_task_result_cmpl_final : server_task_result {

 struct server_task_result_cmpl_partial : server_task_result {
     int index = 0;
+
     std::string content;
+    llama_tokens tokens;

     int32_t n_decoded;
     int32_t n_prompt_tokens;
@@ -619,6 +645,7 @@ struct server_task_result_cmpl_partial : server_task_result {
         json res = json {
             {"index",            index},
             {"content",          content},
+            {"tokens",           tokens},
             {"stop",             false},
             {"id_slot",          id_slot},
             {"tokens_predicted", n_decoded},
@@ -660,7 +687,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             json second_ret = json{
                 {"choices", json::array({json{{"finish_reason", nullptr},
                                             {"index", 0},
-                                            {"delta", json{
+                                            {"delta", json {
                                             {"content", content}}}
                                             }})},
                 {"created", t},
@@ -675,7 +702,7 @@ struct server_task_result_cmpl_partial : server_task_result {
                 {"finish_reason", nullptr},
                 {"index", 0},
                 {"delta",
-                json{
+                json {
                     {"content", content},
                 }},
             }});
@@ -699,32 +726,52 @@ struct server_task_result_cmpl_partial : server_task_result {

 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;
+
+    int32_t n_tokens;
+
+    // OAI-compat fields
+    bool oaicompat = false;

     virtual int get_index() override {
         return index;
     }

     virtual json to_json() override {
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
+        return json {
+            {"index",     index},
+            {"embedding", embedding},
+        };
+    }
+
+    json to_json_oaicompat() {
         return json {
-            {"index",     index},
-            {"embedding", embedding},
+            {"index",            index},
+            {"embedding",        embedding[0]},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };

 struct server_task_result_rerank : server_task_result {
     int index = 0;
     float score = -1e6;

+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }

     virtual json to_json() override {
         return json {
-            {"index", index},
-            {"score", score},
+            {"index",            index},
+            {"score",            score},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
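
Note: widening the member to std::vector<std::vector<float>> lets one result type serve both endpoints: the non-OAI path serializes the full matrix (one row per token when pooling is none), while the OAI path keeps a flat vector by emitting embedding[0] only. A sketch of the two resulting JSON shapes using the same nlohmann::json library (values are illustrative):

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    // non-OAI /embeddings: matrix, one row per token
    json non_oai = {
        {"index",     0},
        {"embedding", {{0.1, 0.2}, {0.3, 0.4}}},
    };

    // OAI /v1/embeddings: single pooled vector
    json oai = {
        {"index",            0},
        {"embedding",        {0.1, 0.2}},
        {"tokens_evaluated", 8},
    };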
@@ -931,8 +978,11 @@ struct server_slot {

     size_t last_nl_pos = 0;

     std::string generated_text;
+    llama_tokens generated_tokens;
+
     llama_tokens cache_tokens;
+
     std::vector<completion_token_output> generated_token_probs;

     bool has_next_token = true;
@@ -976,6 +1026,7 @@ struct server_slot {
         n_sent_token_probs = 0;
         task_type          = SERVER_TASK_TYPE_COMPLETION;

+        generated_tokens.clear();
         generated_token_probs.clear();
     }

@@ -1469,7 +1520,7 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);

         add_bos_token = llama_add_bos_token(model);
-        has_eos_token = !llama_add_eos_token(model);
+        has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;

         if (!params_base.speculative.model.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@@ -1716,8 +1767,10 @@ struct server_context {
         const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;

-        // search stop word and delete it
         slot.generated_text += token_str;
+        if (slot.params.return_tokens) {
+            slot.generated_tokens.push_back(result.tok);
+        }
         slot.has_next_token = true;

         // check if there is incomplete UTF-8 character at the end
@@ -1742,6 +1795,7 @@ struct server_context {
             break;
         }

+        // search stop word and delete it
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

@@ -1894,6 +1948,7 @@ struct server_context {
         res->id      = slot.id_task;
         res->index   = slot.index;
         res->content = tkn.text_to_send;
+        res->tokens  = { tkn.tok };

         res->n_decoded       = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
@@ -1934,6 +1989,7 @@ struct server_context {

         res->index   = slot.index;
         res->content = slot.generated_text;
+        res->tokens  = slot.generated_tokens;
         res->timings = slot.get_timings();
         res->prompt  = common_detokenize(ctx, slot.prompt_tokens, true);

@@ -1975,8 +2031,10 @@ struct server_context {

     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id    = slot.id_task;
-        res->index = slot.index;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->n_tokens  = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;

         const int n_embd = llama_n_embd(model);

@@ -1995,12 +2053,18 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }

-            common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
         }

         SLT_DBG(slot, "%s", "sending embeddings\n");
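
Note: whether embd above is a per-token or a pooled vector depends on how it was fetched. With the standard llama.h API this is typically llama_get_embeddings_ith() for per-token output and llama_get_embeddings_seq() for pooled output; a sketch of that dispatch (assumed surrounding context, not code from the commit):

    // i: batch position, seq_id: sequence id of the slot
    const float * embd = nullptr;
    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
        // one vector per token; requires the batch entry to have been
        // added with embeddings enabled (the need_embd flag below)
        embd = llama_get_embeddings_ith(ctx, i);
    } else {
        // one pooled vector per sequence
        embd = llama_get_embeddings_seq(ctx, seq_id);
    }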
@@ -2012,6 +2076,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id       = slot.id_task;
         res->index    = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;

         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -2613,7 +2678,10 @@ struct server_context {

                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+                    // without pooling, we want to output the embeddings for all the tokens in the batch
+                    const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3381,7 +3449,7 @@ int main(int argc, char ** argv) {
             task.index = i;

             task.prompt_tokens    = std::move(tokenized_prompts[i]);
-            task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
+            task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
             task.id_selected_slot = json_value(data, "id_slot", -1);

             // OAI-compat
@@ -3621,34 +3689,50 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };

-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;

-        // an input prompt can be a string or a list of tokens (integer)
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
         if (body.count("input") != 0) {
-            oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }

+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
             std::vector<server_task> tasks;
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);

                 task.id            = ctx_server.queue_tasks.get_new_id();
                 task.index         = i;
                 task.prompt_tokens = std::move(tokenized_prompts[i]);

                 // OAI-compat
                 task.params.oaicompat = oaicompat;

                 tasks.push_back(task);
             }
@@ -3676,12 +3760,18 @@ int main(int argc, char ** argv) {
         }

         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };

+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3855,7 +3945,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill",        handle_infill);
     svr->Post("/embedding",     handle_embeddings); // legacy
     svr->Post("/embeddings",    handle_embeddings);
-    svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
     svr->Post("/rerank",        handle_rerank);
     svr->Post("/reranking",     handle_rerank);
     svr->Post("/v1/rerank",     handle_rerank);

@@ -10,16 +10,17 @@ def create_server():
     global server
     server = ServerPreset.tinyllama2()

-@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
-    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
-def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": n_predict,
         "prompt": prompt,
+        "return_tokens": return_tokens,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == n_prompt
@@ -27,6 +28,11 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
     assert res.body["truncated"] == truncated
     assert type(res.body["has_new_line"]) == bool
     assert match_regex(re_content, res.body["content"])
+    if return_tokens:
+        assert len(res.body["tokens"]) > 0
+        assert all(type(tok) == int for tok in res.body["tokens"])
+    else:
+        assert res.body["tokens"] == []


 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -56,6 +62,8 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
             assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
+            assert len(data["tokens"]) > 0
+            assert all(type(tok) == int for tok in data["tokens"])
             content += data["content"]


@ -14,8 +14,9 @@ def create_server():

def test_embedding_single():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200

@ -29,8 +30,9 @@ def test_embedding_single():

def test_embedding_multiple():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",

@ -45,10 +47,69 @@ def test_embedding_multiple():
assert len(d['embedding']) > 1


def test_embedding_openai_library_single():
@pytest.mark.parametrize(
"input,is_multi_prompt",
[
# single prompt
("string", False),
([12, 34, 56], False),
([12, 34, "string", 56, 78], False),
# multiple prompts
(["string1", "string2"], True),
(["string1", [12, 34, 56]], True),
([[12, 34, 56], [12, 34, 56]], True),
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
data = res.body['data']
if is_multi_prompt:
assert len(data) == len(input)
for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in data[0]
assert len(data[0]['embedding']) > 1


def test_embedding_pooling_none():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "hello hello hello",
})
assert res.status_code == 200
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special

# make sure each per-token embedding vector is not normalized
for emb in res.body[0]['embedding']:
assert abs(sum(x ** 2 for x in emb) - 1) > EPSILON


def test_embedding_pooling_none_oai():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "hello hello hello",
})

# /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400


def test_embedding_openai_library_single():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1

@ -56,8 +117,9 @@ def test_embedding_openai_library_single():

def test_embedding_openai_library_multiple():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",

@ -71,8 +133,9 @@ def test_embedding_openai_library_multiple():

def test_embedding_error_prompt_too_long():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200

@ -80,8 +143,9 @@ def test_embedding_error_prompt_too_long():


def test_same_prompt_give_same_result():
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",

@ -97,3 +161,33 @@ def test_same_prompt_give_same_result():
vi = res.body['data'][i]['embedding']
for x, y in zip(v0, vi):
assert abs(x - y) < EPSILON


@pytest.mark.parametrize(
"content,n_tokens",
[
("I believe the meaning of life is", 9),
("This is a test", 6),
]
)
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens


def test_embedding_usage_multiple():
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 9

@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
})
assert res.status_code == 400
assert "error" in res.body


@pytest.mark.parametrize(
"query,doc1,doc2,n_tokens",
[
("Machine learning is", "A machine", "Learning is", 19),
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
]
)
def test_rerank_usage(query, doc1, doc2, n_tokens):
global server
server.start()

res = server.make_request("POST", "/rerank", data={
"query": query,
"documents": [
doc1,
doc2,
]
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens

@ -65,6 +65,7 @@ class ServerProcess:
server_reranking: bool | None = False
server_metrics: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None

@ -132,6 +133,8 @@ class ServerProcess:
server_args.append("--metrics")
if self.server_slots:
server_args.append("--slots")
if self.pooling:
server_args.extend(["--pooling", self.pooling])
if self.model_alias:
server_args.extend(["--alias", self.model_alias])
if self.n_ctx:

@ -222,7 +222,6 @@
temperature: 0.7,
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled
penalize_nl: false,
top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled

@ -779,7 +778,6 @@
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

@ -225,7 +225,6 @@
temperature: 0.7,
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled
penalize_nl: false,
top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled

@ -782,7 +781,6 @@
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
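
The comment above enumerates every accepted prompt shape; a minimal request exercising the mixed form (token ids and text interleaved in one prompt), matching test_embedding_mixed_input earlier in this change (host/port are assumptions):

import requests

body = {"input": [[12, 34, 56], [12, "string", 34, 56]]}  # two prompts, the second mixes ids and text
res = requests.post("http://localhost:8080/v1/embeddings", json=body)
assert len(res.json()["data"]) == 2  # one embedding per prompt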

@ -560,6 +561,7 @@ static json oaicompat_completion_params_parse(

static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & elem : embeddings) {
data.push_back(json{

@ -567,14 +569,16 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
{"index", i++},
{"object", "embedding"}
});

n_tokens += json_value(elem, "tokens_evaluated", 0);
}

json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json { // TODO: fill
{"prompt_tokens", 0},
{"total_tokens", 0}
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"data", data}
};

@ -584,20 +588,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso

static json format_response_rerank(const json & request, const json & ranks) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & rank : ranks) {
data.push_back(json{
{"index", i++},
{"relevance_score", json_value(rank, "score", 0.0)},
});

n_tokens += json_value(rank, "tokens_evaluated", 0);
}

json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json { // TODO: fill
{"prompt_tokens", 0},
{"total_tokens", 0}
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", data}
};

7 examples/server/webui/package-lock.json generated

@ -8,6 +8,7 @@
"name": "webui",
"version": "0.0.0",
"dependencies": {
"@sec-ant/readable-stream": "^0.6.0",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",
"daisyui": "^4.12.14",

@ -617,6 +618,12 @@
"win32"
]
},
"node_modules/@sec-ant/readable-stream": {
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/@sec-ant/readable-stream/-/readable-stream-0.6.0.tgz",
"integrity": "sha512-uiBh8DrB5FN35gP6/o8JEhEQ7/ci1jUsOZO/VMUjyvTpjtV54VstOXVj1TvTj/wsT23pfX6butxxh3qufsW3+g==",
"license": "MIT"
},
"node_modules/@vscode/markdown-it-katex": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@vscode/markdown-it-katex/-/markdown-it-katex-1.1.1.tgz",

@ -14,6 +14,7 @@
"vite": "^5.4.10"
},
"dependencies": {
"@sec-ant/readable-stream": "^0.6.0",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",
"daisyui": "^4.12.14",

@ -5,13 +5,16 @@ import TextLineStream from 'textlinestream';

// math formula rendering
import 'katex/dist/katex.min.css';
import markdownItKatexGpt, { renderLatexHTML } from './katex-gpt';
import markdownItKatexGpt from './katex-gpt';
import markdownItKatexNormal from '@vscode/markdown-it-katex';

// code highlighting
import hljs from './highlight-config';
import daisyuiThemes from 'daisyui/src/theming/themes';

// ponyfill for missing ReadableStream asyncIterator on Safari
import { asyncIterator } from "@sec-ant/readable-stream/ponyfill/asyncIterator";

const isDev = import.meta.env.MODE === 'development';

// utility functions

@ -33,7 +36,7 @@ const CONFIG_DEFAULT = {
systemMessage: 'You are a helpful assistant.',
showTokensPerSecond: false,
// make sure these default values are in sync with `common.h`
samplers: 'dkypmxt',
samplers: 'edkypmxt',
temperature: 0.8,
dynatemp_range: 0.0,
dynatemp_exponent: 1.0,

@ -283,7 +286,7 @@ async function* sendSSEPostRequest(url, fetchOptions) {
const lines = res.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(new TextLineStream());
for await (const line of lines) {
for await (const line of asyncIterator(lines)) {
if (isDev) console.log({line});
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
const data = JSON.parse(line.slice(5));

@ -442,7 +445,7 @@ const mainApp = createApp({
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': this.config.apiKey ? `Bearer ${this.config.apiKey}` : undefined,
...(this.config.apiKey ? {'Authorization': `Bearer ${this.config.apiKey}`} : {})
},
body: JSON.stringify(params),
signal: abortController.signal,

5 examples/tts/CMakeLists.txt Normal file

@ -0,0 +1,5 @@
set(TARGET llama-tts)
add_executable(${TARGET} tts.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

180 examples/tts/convert_pt_to_hf.py Normal file

@ -0,0 +1,180 @@
# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token to HF format
# the goal is to be able to reuse the convert_hf_to_gguf.py after that to create a GGUF file with the WavTokenizer decoder
#
# TODO: this script is LLM-generated and probably very inefficient and should be rewritten

import torch
import json
import os
import sys
import re

from safetensors.torch import save_file

# default
model_path = './model.pt'

# read from CLI
if len(sys.argv) > 1:
model_path = sys.argv[1]

# get the directory of the input model
path_dst = os.path.dirname(model_path)

print(f"Loading model from {model_path}")

model = torch.load(model_path, map_location='cpu')

#print(model)

# print all keys
for key in model.keys():
print(key)
if key == 'hyper_parameters':
#print(model[key])
# dump as json pretty
print(json.dumps(model[key], indent=4))
#if key != 'state_dict' and key != 'optimizer_states':
# print(model[key])

# Check if the loaded model is a state_dict or a model instance
if isinstance(model, torch.nn.Module):
state_dict = model.state_dict()
else:
state_dict = model

# Print the structure of the state_dict to understand its format
print("State dictionary keys:")
for key in state_dict.keys():
print(key)

# Ensure the state_dict is flat and contains only torch.Tensor objects
def flatten_state_dict(state_dict, parent_key='', sep='.'):
items = []
items_new = []

for k, v in state_dict.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, torch.Tensor):
items.append((new_key, v))
elif isinstance(v, dict):
items.extend(flatten_state_dict(v, new_key, sep=sep).items())
return dict(items)

size_total_mb = 0

for key, value in list(items):
# keep only what we need for inference
if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
not key.startswith('state_dict.backbone.') and \
not key.startswith('state_dict.head.out'):
print('Skipping key: ', key)
continue

new_key = key

new_key = new_key.replace('state_dict.', '')
new_key = new_key.replace('pos_net', 'posnet')

# check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
if new_key.startswith("backbone.posnet."):
match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
if match:
new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"

# "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
new_key = "backbone.embedding.weight"

# these are the only rows used
# ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
if new_key.endswith("norm.scale.weight"):
new_key = new_key.replace("norm.scale.weight", "norm.weight")
value = value[0]

if new_key.endswith("norm.shift.weight"):
new_key = new_key.replace("norm.shift.weight", "norm.bias")
value = value[0]

if new_key.endswith("gamma"):
new_key = new_key.replace("gamma", "gamma.weight")

# convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
value = value.unsqueeze(1)

if new_key.endswith("dwconv.bias"):
value = value.unsqueeze(1)

size_mb = value.element_size() * value.nelement() / (1024 * 1024)
print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")

size_total_mb += size_mb

#print(key, '->', new_key, ': ', value)
#print(key, '->', new_key)

items_new.append((new_key, value))

print(f"Total size: {size_total_mb:8.2f} MB")

return dict(items_new)

flattened_state_dict = flatten_state_dict(state_dict)


# Convert the model to the safetensors format
output_path = path_dst + '/model.safetensors'
save_file(flattened_state_dict, output_path)

print(f"Model has been successfully converted and saved to {output_path}")

# Calculate the total size of the .safetensors file
total_size = os.path.getsize(output_path)

# Create the weight map
weight_map = {
"model.safetensors": ["*"] # Assuming all weights are in one file
}

# Create metadata for the index.json file
metadata = {
"total_size": total_size,
"weight_map": weight_map
}

# Save the metadata to index.json
index_path = path_dst + '/index.json'
with open(index_path, 'w') as f:
json.dump(metadata, f, indent=4)

print(f"Metadata has been saved to {index_path}")

config = {
"architectures": [
"WavTokenizerDec"
],
"hidden_size": 1282,
"n_embd_features": 512,
"n_ff": 2304,
"vocab_size": 4096,
"n_head": 1,
"layer_norm_epsilon": 1e-6,
"group_norm_epsilon": 1e-6,
"group_norm_groups": 32,
"max_position_embeddings": 8192, # ?
"n_layer": 12,
"posnet": {
"n_embd": 768,
"n_layer": 6
},
"convnext": {
"n_embd": 768,
"n_layer": 12
},
}

with open(path_dst + '/config.json', 'w') as f:
json.dump(config, f, indent=4)

print(f"Config has been saved to {path_dst + '/config.json'}")
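
As the header comment says, the output directory is meant to feed back into the regular conversion flow; under that assumption the intended usage is roughly `python examples/tts/convert_pt_to_hf.py /path/to/model.pt` followed by pointing `convert_hf_to_gguf.py` at the directory containing the generated model.safetensors, index.json and config.json.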

175 examples/tts/tts-outetts.py Normal file

@ -0,0 +1,175 @@
import sys
#import json
#import struct
import requests
import re

def process_text(text: str):
text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
text = re.sub(r'[-_/,\.\\]', ' ', text)
text = re.sub(r'[^a-z\s]', '', text)
text = re.sub(r'\s+', ' ', text).strip()
return text.split()
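# example: process_text("The overall package, from just 2 people!") returns
# ['the', 'overall', 'package', 'from', 'just', 'people'] - note that digits
# are currently dropped by the non-alpha filter, which is what the TODO above refers to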

# usage:
# python tts-outetts.py http://server-llm:port http://server-dec:port "text"

if len(sys.argv) <= 3:
print("usage: python tts-outetts.py http://server-llm:port http://server-dec:port \"text\"")
exit(1)

host_llm = sys.argv[1]
host_dec = sys.argv[2]
text = sys.argv[3]

prefix = """<|im_start|>
<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>"""

words = process_text(text)
words = "<|text_sep|>".join([i.strip() for i in words])
words += "<|text_end|>\n"

# voice data
# TODO: load from json
#suffix = """<|audio_start|>
#the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
#overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
#package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
#from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|>
#just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|>
#two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|>
#people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|>
#is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|>
#pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|>
#remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|>
#sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|>
#i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|>
#have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|>
#some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|>
#critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|>
#about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|>
#some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|>
#of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|>
#the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|>
#gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|>
#aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|>
#but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|>
#its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|>
#still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|>
#really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|>
#enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|>
#and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|>
#it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|>
#looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
#lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>"""

# TODO: tokenization is slow for some reason - here is pre-tokenized input
suffix = [ 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, 152460, 153375, 151670, 198, 74455,
155808, 151669, 151799, 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, 151970, 153413,
152419, 153334, 153289, 153374, 153199, 152040, 153260, 152721, 152680, 153297, 152419, 153248, 152400,
152691, 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, 152256, 152991, 152299, 152688, 153163,
153016, 152789, 153198, 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, 152461, 153321,
153309, 151750, 152137, 153340, 152573, 152267, 153347, 151789, 152681, 153339, 151992, 152512, 151751,
152179, 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, 152311, 151670, 198, 1499, 155791,
151669, 152276, 152454, 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, 153043, 152325,
153267, 152622, 151670, 198, 4250, 155797, 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, 152112, 153204, 151722, 152542, 151670, 198,
19789, 155796, 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, 152191, 151734, 152312, 152810,
152237, 153224, 153169, 153224, 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, 152265, 151946,
151808, 152412, 152363, 152305, 153156, 152733, 152810, 153157, 152016, 152100, 152069, 153234, 152317,
152589, 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, 153376, 152272, 152433, 152325,
151941, 151670, 198, 285, 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, 152474, 152680,
152157, 153255, 152324, 151682, 151670, 198, 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, 153070, 151883, 152890, 152489, 153144,
153375, 152358, 151685, 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, 151902, 152720,
153377, 152027, 152378, 152821, 153207, 153459, 153028, 153068, 152507, 153255, 152158, 152921, 151958,
152609, 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, 152606, 152162, 152186, 153071,
152244, 153118, 153375, 153018, 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, 153380,
153502, 152702, 152115, 153181, 152735, 153277, 153457, 152393, 153112, 152595, 151670, 198, 19098, 155808,
151669, 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, 153163, 152922, 153402, 152034,
152591, 153438, 152215, 151673, 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, 152718,
152862, 153347, 151670, 198, 72, 155780, 151669, 151795, 152111, 152746, 152377, 153471, 152309, 151670, 198,
19016, 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, 152939, 152536, 152091, 151815, 152733,
151672, 151670, 198, 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, 153504, 152589, 153333,
151839, 151941, 153038, 153180, 151670, 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, 152801,
152985, 153400, 152393, 152818, 152765, 152249, 152600, 151699, 152302, 152752, 153018, 153009, 151992,
153054, 152847, 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, 152048, 152757, 152428,
153195, 151906, 153006, 153178, 153250, 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, 153321, 152217, 153039, 152935, 153400, 152122,
152531, 153106, 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, 152901, 152885, 152594,
153446, 153080, 151670, 198, 14689, 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, 151673,
151690, 151698, 152714, 152846, 152981, 153171, 153384, 153364, 153188, 153246, 151670, 198, 1055, 155779,
151669, 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, 155780, 151669, 153483, 153240, 152241,
152558, 152697, 153046, 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, 153034, 153434,
153372, 153347, 151887, 152453, 152758, 152133, 152510, 152694, 152431, 152321, 153088, 152676, 152223,
152581, 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, 152903, 152859, 152989, 151748,
152669, 152661, 152650, 152409, 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, 152988,
152894, 151819, 152391, 153019, 152058, 153062, 153230, 151826, 152112, 152306, 152264, 152769, 153390,
152384, 152435, 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, 151919, 151893, 152558,
152817, 152946, 152956, 152129, 152715, 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, 152679, 152533, 152382, 152374, 152611, 153341,
153163, 152285, 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, 151669, 151764, 152360, 153295,
152634, 153342, 152199, 152271, 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, 152016, 152385,
152629, 152495, 151826, 153321, 152958, 152180, 151886, 153432, 152922, 152128, 153024, 153040, 152593,
152287, 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, 152680, 153331, 151699, 152316, 152938,
152289, 152433, 153384, 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, 152489, 151941,
152049, 152034, 153053, 152179, 153160, 151676, 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, 153135, 152291, 153235, 152143, 152583,
152402, 153483, 152678, 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, 152548, 153442,
152109, 152659, 153325, 152781, 152570, 152957, 151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, 151792, 153409, 153327, 152990, 151670, 198,
275, 155781, 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, 151670, 198, 94273, 155799,
151669, 152953, 152938, 153427, 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, 152257,
152987, 152777, 153448, 152408, 151696, 152408, 152326, 152699, 151670, 198, 385, 16239, 155828, 151669,
152306, 152268, 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, 152918, 152923, 152467,
152331, 153053, 153330, 151889, 153444, 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, 152109, 152255, 151739, 152267, 152759,
153318, 153165, 153349, 151670, ]

response = requests.post(
host_llm + "/completion",
json={
"prompt": [prefix + words, *suffix],
"n_predict": 1024,
"cache_prompt": True,
"return_tokens": True,
"samplers": ["top_k"],
"top_k": 16,
"seed": 1003,
}
)

response_json = response.json()

#print(json.dumps(response_json, indent=4))
#print(json.dumps(response_json["prompt"], indent=4).replace("\\n", "\n"))
#print(json.dumps(response_json["timings"], indent=4))
#print(json.dumps(response_json["tokens"], indent=4))

codes = response_json["tokens"]

codes = [t - 151672 for t in codes if t >= 151672 and t <= 155772]

response = requests.post(
host_dec + "/embeddings",
json={
"input": [*codes],
}
)

response_json = response.json()

#print(json.dumps(response_json, indent=4))

# spectrogram
embd = response_json[0]["embedding"]

n_codes = len(embd)
n_embd = len(embd[0])

print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))

# post-process the spectrogram to convert to audio
# TODO: see the tts.cpp:embd_to_audio() and implement it in Python
print('converting to audio ...')
print('TODO: see the tts.cpp:embd_to_audio() and implement it in Python')
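
The TODO above asks for a Python port of tts.cpp:embd_to_audio(), which is added below in this same commit. A minimal NumPy sketch under the same constants (n_fft = 1280, hop 320, each embedding holding log-magnitudes in its first half and phases in its second); it is an untested approximation, and NumPy's irfft normalizes differently from the simplified DFT in tts.cpp, so the absolute amplitude may differ by a constant factor:

import numpy as np

def embd_to_audio(embd, n_fft = 1280, n_hop = 320, n_win = 1280):
    E = np.asarray(embd, dtype=np.float32)             # [n_codes, n_embd], n_embd == 2 * (n_fft/2 + 1)
    n_codes, n_embd = E.shape
    n_pad = (n_win - n_hop) // 2
    n_out = (n_codes - 1) * n_hop + n_win

    mag = np.minimum(np.exp(E[:, :n_embd // 2]), 1e2)  # log-magnitudes -> magnitudes, clamped as in tts.cpp
    phi = E[:, n_embd // 2:]                           # phases
    spec = mag * (np.cos(phi) + 1j * np.sin(phi))      # complex half-spectrum per frame

    # periodic Hann window, matching fill_hann_window(n_fft, true, ...)
    hann = 0.5 * (1.0 - np.cos(2.0 * np.pi * np.arange(n_fft) / n_fft))
    frames = np.fft.irfft(spec, n=n_fft, axis=1) * hann

    # overlap-add plus window-envelope normalization (the two fold() calls in tts.cpp)
    audio = np.zeros(n_out, dtype=np.float64)
    env = np.zeros(n_out, dtype=np.float64)
    for l in range(n_codes):
        audio[l * n_hop : l * n_hop + n_win] += frames[l]
        env[l * n_hop : l * n_hop + n_win] += hann * hann
    audio = audio[n_pad : n_out - n_pad] / env[n_pad : n_out - n_pad]
    return audio.astype(np.float32)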

932 examples/tts/tts.cpp Normal file

@ -0,0 +1,932 @@
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#define _USE_MATH_DEFINES // For M_PI on MSVC

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <map>
#include <regex>
#include <string>
#include <thread>
#include <vector>

//
// Terminal utils
//

#define SQR(X) ((X) * (X))
#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40

/**
* Quantizes 24-bit RGB to xterm256 code range [16,256).
*/
static int rgb2xterm256(int r, int g, int b) {
unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377};
int av, ir, ig, ib, il, qr, qg, qb, ql;
av = r * .299 + g * .587 + b * .114 + .5;
ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8;
qr = cube[(ir = UNCUBE(r))];
qg = cube[(ig = UNCUBE(g))];
qb = cube[(ib = UNCUBE(b))];
if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=
SQR(ql - r) + SQR(ql - g) + SQR(ql - b))
return ir * 36 + ig * 6 + ib + 020;
return il + 0350;
}

static std::string set_xterm256_foreground(int r, int g, int b) {
int x = rgb2xterm256(r, g, b);
std::ostringstream oss;
oss << "\033[38;5;" << x << "m";
return oss.str();
}

const std::vector<std::string> k_colors = {
set_xterm256_foreground(220, 5, 12),
set_xterm256_foreground(232, 96, 28),
set_xterm256_foreground(241, 147, 45),
set_xterm256_foreground(246, 193, 65),
set_xterm256_foreground(247, 240, 86),
set_xterm256_foreground(144, 201, 135),
set_xterm256_foreground( 78, 178, 101),
};

static void print_usage(int, char ** argv) {
LOG("\nexample usage:\n");
LOG("\n %s -m model.gguf -p \"Hello!\"\n", argv[0]);
LOG("\n");
}

struct wav_header {
char riff[4] = {'R', 'I', 'F', 'F'};
uint32_t chunk_size;
char wave[4] = {'W', 'A', 'V', 'E'};
char fmt[4] = {'f', 'm', 't', ' '};
uint32_t fmt_chunk_size = 16;
uint16_t audio_format = 1; // PCM
uint16_t num_channels = 1; // Mono
uint32_t sample_rate;
uint32_t byte_rate;
uint16_t block_align;
uint16_t bits_per_sample = 16;
char data[4] = {'d', 'a', 't', 'a'};
uint32_t data_size;
};

static void save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
std::ofstream file(fname, std::ios::binary);
if (!file) {
LOG_ERR("%s: Failed to open file '%s' for writing", __func__, fname.c_str());
return;
}

wav_header header;
header.sample_rate = sample_rate;
header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
header.block_align = header.num_channels * (header.bits_per_sample / 8);
header.data_size = data.size() * (header.bits_per_sample / 8);
header.chunk_size = 36 + header.data_size;

file.write(reinterpret_cast<const char*>(&header), sizeof(header));

for (const auto & sample : data) {
int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
file.write(reinterpret_cast<const char*>(&pcm_sample), sizeof(pcm_sample));
}

file.close();
}

static void fill_hann_window(int length, bool periodic, float * output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}

// very poor-man fft
static void twiddle(float * real, float * imag, int k, int N) {
float angle = 2 * M_PI * k / N;
*real = cos(angle);
*imag = sin(angle);
}

static void irfft(int n, const float * inp_cplx, float * out_real) {
int N = n / 2 + 1;

std::vector<float> real_input(N);
std::vector<float> imag_input(N);
for (int i = 0; i < N; ++i) {
real_input[i] = inp_cplx[2 * i];
imag_input[i] = inp_cplx[2 * i + 1];
}

std::vector<float> real_output(n);
std::vector<float> imag_output(n);

for (int k = 0; k < n; ++k) {
real_output[k] = 0.0f;
imag_output[k] = 0.0f;
for (int m = 0; m < N; ++m) {
float twiddle_real;
float twiddle_imag;

twiddle(&twiddle_real, &twiddle_imag, k * m, n);

real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
}
}

for (int i = 0; i < n; ++i) {
out_real[i] = real_output[i] / N;
}
}

//
// y = torch.nn.functional.fold(
// data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
// )[:, 0, 0, pad:-pad]
//
// data.shape = torch.Size([1, 1280, 261])
// output_size = 84480
// win_length = 1280
// hop_length = 320
// pad = 480
//
static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
int64_t output_height = n_out;
int64_t kernel_w = n_win;
int64_t stride_w = n_hop;
int64_t width = n_out;

output.resize(width, 0.0f);

int64_t col_idx = 0;
for (int64_t w_col = 0; w_col < width; ++w_col) {
int64_t start = w_col * stride_w - n_pad;
int64_t end = start + kernel_w;

for (int64_t w_im = start; w_im < end; ++w_im) {
if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
output[w_im] += data[col_idx];
}
col_idx++;
}
}

output.resize(n_out - 2 * n_pad);
}

// TODO: not optimized at all
static std::vector<float> embd_to_audio(
const float * embd,
const int n_codes,
const int n_embd,
const int n_thread) {
const int n_fft = 1280;
const int n_hop = 320;
const int n_win = 1280;
const int n_pad = (n_win - n_hop)/2;
const int n_out = (n_codes - 1)*n_hop + n_win;

std::vector<float> hann(n_fft);

fill_hann_window(hann.size(), true, hann.data());

int n_spec = n_embd*n_codes;

std::vector<float> E (n_spec);
std::vector<float> S (n_spec);
std::vector<float> ST(n_spec);

for (int l = 0; l < n_codes; ++l) {
for (int k = 0; k < n_embd; ++k) {
E[k*n_codes + l] = embd[l*n_embd + k];
}
}

for (int k = 0; k < n_embd/2; ++k) {
for (int l = 0; l < n_codes; ++l) {
float mag = E[(k )*n_codes + l];
float phi = E[(k + n_embd/2)*n_codes + l];

mag = exp(mag);

if (mag > 1e2) {
mag = 1e2;
}
S[2*(k*n_codes + l) + 0] = mag*cosf(phi);
S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
}
}

for (int l = 0; l < n_codes; ++l) {
for (int k = 0; k < n_embd/2; ++k) {
ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
}
}

std::vector<float> res (n_codes*n_fft);
std::vector<float> hann2(n_codes*n_fft);

std::vector<std::thread> workers(n_thread);
for (int i = 0; i < n_thread; ++i) {
workers[i] = std::thread([&, i]() {
for (int l = i; l < n_codes; l += n_thread) {
irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
for (int j = 0; j < n_fft; ++j) {
res [l*n_fft + j] *= hann[j];
hann2[l*n_fft + j] = hann[j] * hann[j];
}
}
});
}
for (int i = 0; i < n_thread; ++i) {
workers[i].join();
}

std::vector<float> audio;
std::vector<float> env;

fold(res, n_out, n_win, n_hop, n_pad, audio);
fold(hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once

for (size_t i = 0; i < audio.size(); ++i) {
audio[i] /= env[i];
}

return audio;
}
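
Taken together, the two fold() calls implement standard windowed overlap-add with envelope normalization: with Hann window w, hop h, and per-frame inverse transforms x_l, the reconstruction is audio[t] = sum_l w[t - l*h] * x_l[t - l*h] / sum_l w[t - l*h]^2, with n_pad samples trimmed from each end; the denominator cancels the synthesis windowing wherever consecutive Hann windows overlap.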

static const std::map<int, std::string> ones = {
{0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
{5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
{10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
{15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
};

static const std::map<int, std::string> tens = {
{2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
{6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
};

// Convert a number less than 1000 to words
static std::string convert_less_than_thousand(int num) {
std::string result;

if (num >= 100) {
result += ones.at(num / 100) + " hundred ";
num %= 100;
}

if (num >= 20) {
result += tens.at(num / 10);
if (num % 10 > 0) {
result += "-" + ones.at(num % 10);
}
} else if (num > 0) {
result += ones.at(num);
}

return result;
}

static std::string number_to_words(const std::string & number_str) {
try {
size_t decimal_pos = number_str.find('.');
std::string integer_part = number_str.substr(0, decimal_pos);

int int_number = std::stoi(integer_part);
std::string result;

if (int_number == 0) {
result = "zero";
} else {
if (int_number >= 1000000000) {
int billions = int_number / 1000000000;
result += convert_less_than_thousand(billions) + " billion ";
int_number %= 1000000000;
}

if (int_number >= 1000000) {
int millions = int_number / 1000000;
result += convert_less_than_thousand(millions) + " million ";
int_number %= 1000000;
}

if (int_number >= 1000) {
int thousands = int_number / 1000;
result += convert_less_than_thousand(thousands) + " thousand ";
int_number %= 1000;
}

if (int_number > 0) {
result += convert_less_than_thousand(int_number);
}
}

// Handle decimal part
if (decimal_pos != std::string::npos) {
result += " point";
std::string decimal_part = number_str.substr(decimal_pos + 1);
for (char digit : decimal_part) {
result += " " + ones.at(digit - '0');
}
}

return result;
} catch (const std::exception& e) {
// Skip if fails
return " ";
}
}

static std::string replace_numbers_with_words(const std::string & input_text) {
std::regex number_pattern(R"(\d+(\.\d+)?)");
std::string result;
auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
auto end = std::sregex_iterator();

size_t last_pos = 0;
for (std::sregex_iterator i = it; i != end; ++i) {
const std::smatch& match = *i;
result.append(input_text, last_pos, match.position() - last_pos);
result.append(number_to_words(match.str()));
last_pos = match.position() + match.length();
}
result.append(input_text, last_pos);

return result;
}

// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
static std::string process_text(const std::string & text) {

// For now I skipped text romanization as I am unsure how to handle
// uroman and MeCab implementations in C++
// maybe something like https://github.com/anyascii/anyascii/ could work.
// currently only English would be supported in this function

std::string processed_text = replace_numbers_with_words(text);

std::transform(processed_text.begin(), processed_text.end(),
processed_text.begin(), ::tolower);

std::regex special_chars(R"([-_/,\.\\])");
processed_text = std::regex_replace(processed_text, special_chars, " ");

std::regex non_alpha(R"([^a-z\s])");
processed_text = std::regex_replace(processed_text, non_alpha, "");

std::regex multiple_spaces(R"(\s+)");
processed_text = std::regex_replace(processed_text, multiple_spaces, " ");

processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");

/*
Replace spaces with the separator token same as in line 365

for (auto & c : prompt_user) {
if (c == ' ') {
prompt_clean += "<|text_sep|>";
*/
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");

return processed_text;
}

static void prompt_add(llama_tokens & prompt, llama_token token) {
prompt.push_back(token);
}

static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
prompt.insert(prompt.end(), tokens.begin(), tokens.end());
}

static void prompt_add(llama_tokens & prompt, const llama_model * model, const std::string & txt, bool add_special, bool parse_special) {
auto tmp = common_tokenize(model, txt, add_special, parse_special);
prompt_add(prompt, tmp);
}

static void prompt_init(llama_tokens & prompt, const llama_model * model) {
prompt.clear();

prompt_add(prompt, model, "<|im_start|>\n", true, true);
}

int main(int argc, char ** argv) {
common_params params;

params.prompt = "";

params.n_predict = 4096;
params.n_batch = 8192;
params.n_ctx = 8192;

params.sampling.top_k = 4;
params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
return 1;
}

const int n_parallel = params.n_parallel;
const int n_predict = params.n_predict;

common_init();

// init LLM

llama_backend_init();
llama_numa_init(params.numa);

llama_model * model_ttc = NULL; // text-to-codes
llama_model * model_cts = NULL; // codes-to-speech

llama_context * ctx_ttc = NULL;
llama_context * ctx_cts = NULL;

common_init_result llama_init_ttc = common_init_from_params(params);
model_ttc = llama_init_ttc.model;
ctx_ttc = llama_init_ttc.context;

// TODO: refactor in a common struct
params.model = params.vocoder.model;
params.model_url = params.vocoder.model_url;
params.hf_repo = params.vocoder.hf_repo;
params.hf_file = params.vocoder.hf_file;

params.embedding = true;

common_init_result llama_init_cts = common_init_from_params(params);
model_cts = llama_init_cts.model;
ctx_cts = llama_init_cts.context;

std::vector<common_sampler *> smpl(n_parallel);
for (int i = 0; i < n_parallel; ++i) {
params.sampling.no_perf = (i != 0);
params.sampling.seed = params.sampling.seed + 1;

smpl[i] = common_sampler_init(model_ttc, params.sampling);
}

LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl[0]));
LOG_INF("sampler params: \n%s\n", params.sampling.print().c_str());
LOG_INF("sampler chain: %s\n", common_sampler_print(smpl[0]).c_str());

LOG_INF("%s: loading done\n", __func__);

const auto t_main_start = ggml_time_us();

std::vector<llama_token> codes;

// process prompt and generate voice codes
{
LOG_INF("%s: constructing prompt ..\n", __func__);

std::vector<llama_token> prompt_inp;

prompt_init(prompt_inp, model_ttc);

prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);

// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);

LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());

prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
}

prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);

// disabled to save time on tokenizing each time
// TODO: load voices from the json files
#if 0
const std::string voice_data = R"(<|audio_start|>
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|>
just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|>
two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|>
people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|>
is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|>
pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|>
remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|>
sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|>
i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|>
have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|>
some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|>
critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|>
about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|>
some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|>
of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|>
the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|>
gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|>
|
||||
aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|>
|
||||
but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|>
|
||||
its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|>
|
||||
still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|>
|
||||
really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|>
|
||||
enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|>
|
||||
and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|>
|
||||
it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|>
|
||||
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
|
||||
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
|
||||
|
||||
auto tmp = common_tokenize(model_ttc, voice_data, false, true);
|
||||
printf("\n\n");
|
||||
for (int i = 0; i < tmp.size(); ++i) {
|
||||
printf("%d, ", tmp[i]);
|
||||
}
|
||||
printf("\n\n");
|
||||
#else
|
||||
prompt_add(prompt_inp, llama_tokens {
|
||||
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
|
||||
152460, 153375, 151670, 198, 74455, 155808, 151669, 151799,
|
||||
151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470,
|
||||
151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040,
|
||||
153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691,
|
||||
153368, 153437, 151670, 198, 1722, 155828, 151669, 152607,
|
||||
152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198,
|
||||
152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207,
|
||||
152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267,
|
||||
153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179,
|
||||
153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904,
|
||||
152311, 151670, 198, 1499, 155791, 151669, 152276, 152454,
|
||||
153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226,
|
||||
153043, 152325, 153267, 152622, 151670, 198, 4250, 155797,
|
||||
151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
|
||||
152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213,
|
||||
152112, 153204, 151722, 152542, 151670, 198, 19789, 155796,
|
||||
151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002,
|
||||
152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224,
|
||||
152244, 153387, 153404, 151670, 198, 16069, 155811, 151669,
|
||||
152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733,
|
||||
152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589,
|
||||
152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504,
|
||||
153376, 152272, 152433, 152325, 151941, 151670, 198, 285,
|
||||
155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381,
|
||||
152474, 152680, 152157, 153255, 152324, 151682, 151670, 198,
|
||||
32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
|
||||
152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488,
|
||||
153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685,
|
||||
152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669,
|
||||
151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459,
|
||||
153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609,
|
||||
152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470,
|
||||
152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018,
|
||||
152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736,
|
||||
153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457,
|
||||
152393, 153112, 152595, 151670, 198, 19098, 155808, 151669,
|
||||
152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239,
|
||||
153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673,
|
||||
152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482,
|
||||
152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795,
|
||||
152111, 152746, 152377, 153471, 152309, 151670, 198, 19016,
|
||||
155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701,
|
||||
152939, 152536, 152091, 151815, 152733, 151672, 151670, 198,
|
||||
14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042,
|
||||
153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670,
|
||||
198, 36996, 8303, 155832, 151669, 152231, 152256, 152835,
|
||||
152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600,
|
||||
151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847,
|
||||
153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458,
|
||||
152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250,
|
||||
152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
|
||||
152228, 152733, 151670, 198, 9096, 155801, 151669, 151698,
|
||||
153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106,
|
||||
152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851,
|
||||
152901, 152885, 152594, 153446, 153080, 151670, 198, 14689,
|
||||
155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191,
|
||||
151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384,
|
||||
153364, 153188, 153246, 151670, 198, 1055, 155779, 151669,
|
||||
151869, 152388, 152711, 153334, 151736, 151670, 198, 1782,
|
||||
155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046,
|
||||
151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605,
|
||||
153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133,
|
||||
152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581,
|
||||
152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032,
|
||||
152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409,
|
||||
151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469,
|
||||
152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230,
|
||||
151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435,
|
||||
152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540,
|
||||
151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715,
|
||||
153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
|
||||
151670, 198, 8088, 155792, 151669, 152452, 153497, 153353,
|
||||
152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285,
|
||||
153411, 152495, 153141, 152320, 151670, 198, 1199, 155781,
|
||||
151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271,
|
||||
151670, 198, 43366, 155799, 151669, 152308, 151682, 152889,
|
||||
152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180,
|
||||
151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287,
|
||||
151677, 151670, 198, 53660, 155808, 151669, 151727, 152092,
|
||||
152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384,
|
||||
151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691,
|
||||
152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676,
|
||||
153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
|
||||
152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234,
|
||||
153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678,
|
||||
152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825,
|
||||
152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957,
|
||||
151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
|
||||
151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174,
|
||||
151792, 153409, 153327, 152990, 151670, 198, 275, 155781,
|
||||
151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974,
|
||||
151670, 198, 94273, 155799, 151669, 152953, 152938, 153427,
|
||||
152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331,
|
||||
152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326,
|
||||
152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268,
|
||||
153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110,
|
||||
152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444,
|
||||
152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
|
||||
152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499,
|
||||
152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
|
||||
151670,});
|
||||
#endif
|
||||
|
||||
// print the prompt token-by-token
|
||||
|
||||
LOG("\n");
|
||||
|
||||
for (auto id : prompt_inp) {
|
||||
LOG("%s", common_token_to_piece(ctx_ttc, id).c_str());
|
||||
}
|
||||
|
||||
LOG_INF("%s: prompt size: %d\n", __func__, (int) prompt_inp.size());
|
||||
|
||||
LOG("\n");
|
||||
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
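// note: the batch must hold the whole prompt for the first llama_decode and
// at least one token per parallel sequence on every step after that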

std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
seq_ids[i] = i;
}

// evaluate the initial prompt
for (size_t i = 0; i < prompt_inp.size(); ++i) {
common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
}
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

if (llama_decode(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}

if (n_parallel > 1) {
LOG_INF("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
}

llama_synchronize(ctx_ttc);

LOG_INF("%s: time for prompt: %.3f ms\n\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);

const auto t_dec_start = ggml_time_us();

// main loop

// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);

int n_past = batch.n_tokens;
int n_decode = 0;

while (n_decode <= n_predict) {
// prepare the next batch
common_batch_clear(batch);

// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
if (i_batch[i] < 0) {
// the stream has already finished
continue;
}

const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);

common_sampler_accept(smpl[i], new_token_id, true);

codes.push_back(new_token_id);

const auto * cands = common_sampler_get_candidates(smpl[i]);

// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model_ttc, new_token_id) || n_decode == n_predict) {
std::string reason;
if (llama_token_is_eog(model_ttc, new_token_id)) {
reason = "eos";
} else {
reason = "n_predict";
}

i_batch[i] = -1;

LOG("\n");
if (n_parallel > 1) {
LOG_CNT("\n");
LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n", __func__, i, n_past, reason.c_str());
}

continue;
}

{
const float p = cands->data[cands->selected].p;

const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) ((3*p)*float(k_colors.size()))));
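// maps the sampled probability p to a color index: the (3*p) scaling means
// any p >= 1/3 already saturates at the last (most confident) color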

LOG_CNT("%s%d%s", k_colors[col].c_str(), i, "\033[0m");
//LOG_CNT("%d", i);
}

i_batch[i] = batch.n_tokens;

// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past, { i }, true);
}

// all streams are finished
if (batch.n_tokens == 0) {
break;
}

n_decode += 1;
n_past += 1;

// evaluate the current batch with the transformer model
if (llama_decode(ctx_ttc, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}

llama_batch_free(batch);

LOG("\n");
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
}

common_perf_print(ctx_ttc, smpl[0]);

//std::vector<llama_token> codes = {198, 88225, 155856, 151669, 152205,
//    153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695,
//    153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010,
//    153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286,
//    152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296,
//    153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690,
//    153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061,
//    153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670,
//    198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683,
//    152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908,
//    151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359,
//    153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424,
//    151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670,
//    198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729,
//    152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669,
//    153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670,
//    198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501,
//    152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242,
//    153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360,
//    153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055,
//    152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670,
//    198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441,
//    152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831,
//    153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133,
//    153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109,
//    152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055,
//    155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729,
//    151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337,
//    153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153,
//    153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365,
//    153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218,
//    152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464,
//    152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855,
//    152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418,
//    153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645};

{
const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);

LOG("\n");
LOG_INF("codes: '%s'\n", inp_txt.c_str());
LOG_INF("%s: codes size: %d\n", __func__, (int) codes.size());
}

// remove all non-audio tokens (i.e. < 151672 || > 155772)
codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());

{
const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
LOG_INF("codes audio: '%s'\n", inp_txt.c_str());
LOG_INF("%s: codes audio size: %d\n", __func__, (int) codes.size());
}

for (auto & token : codes) {
token -= 151672;
}
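// after the shift the codes lie in [0, 4100] (= 155772 - 151672), the range
// consumed by the vocoder below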

const auto t_voc_start = ggml_time_us();

const int n_codes = codes.size();

llama_batch batch = llama_batch_init(n_codes, 0, 1);

for (size_t i = 0; i < codes.size(); ++i) {
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
}
GGML_ASSERT(batch.n_tokens == n_codes);

if (llama_decode(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}

llama_synchronize(ctx_cts);

LOG_INF("%s: time for vocoder: %.3f ms\n", __func__, (ggml_time_us() - t_voc_start) / 1000.0f);

const auto t_spec_start = ggml_time_us();

#if 1
// spectral operations
const int n_embd = llama_n_embd(model_cts);
const float * embd = llama_get_embeddings(ctx_cts);

auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);
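// embd_to_audio (defined elsewhere in this example) turns the n_codes x n_embd
// embeddings into PCM samples; the #else branch below does the same from a
// spectrogram dumped to disk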

#else
// read the spectrogram from a file for debugging purposes
std::vector<float> audio;
{
std::ifstream fin("out.bin", std::ios::binary);
if (!fin) {
LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin");
return 1;
}

std::vector<float> embd;

int n_codes;
int n_embd;

fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int));
fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int));

embd.resize(n_codes * n_embd);
fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float));
fin.close();

LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd);

audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
}
#endif

const std::string fname = "output.wav";

const int n_sr = 24000; // sampling rate

// zero out first 0.25 seconds
for (int i = 0; i < 24000/4; ++i) {
audio[i] = 0.0f;
}

LOG_INF("%s: time for spectral ops: %.3f ms\n", __func__, (ggml_time_us() - t_spec_start) / 1000.0f);
LOG_INF("%s: total time: %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);

save_wav16(fname, audio, n_sr);

LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str());

llama_free(ctx_ttc);
llama_free_model(model_ttc);

llama_free(ctx_cts);
llama_free_model(model_cts);

llama_backend_free();

return 0;
}

@ -74,10 +74,10 @@ if (NOT GGML_CUDA_GRAPHS_DEFAULT)
endif()

# general
option(GGML_STATIC "ggml: static link libraries" OFF)
option(GGML_NATIVE "ggml: enable -march=native flag" ${GGML_NATIVE_DEFAULT})
option(GGML_LTO "ggml: enable link time optimization" OFF)
option(GGML_CCACHE "ggml: use ccache if available" ON)
option(GGML_STATIC "ggml: static link libraries" OFF)
option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT})
option(GGML_LTO "ggml: enable link time optimization" OFF)
option(GGML_CCACHE "ggml: use ccache if available" ON)

# debug
option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)

@ -120,8 +120,9 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_SVE "ggml: enable SVE" OFF)

option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")


if (WIN32)

@ -1564,17 +1564,6 @@ extern "C" {
int d1, // dilation dimension 1
bool is_2D);

GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride dimension 0
int s1, // stride dimension 1
int p0, // padding dimension 0
int p1, // padding dimension 1
int d0, // dilation dimension 0
int d1); // dilation dimension 1

GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel

@ -1592,6 +1581,23 @@ extern "C" {
int s, // stride
int d); // dilation

// depthwise
// TODO: this is very likely wrong for some cases! - needs more testing
GGML_API struct ggml_tensor * ggml_conv_1d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride
int p0, // padding
int d0); // dilation

GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride
int d0); // dilation

GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel

@ -1611,7 +1617,6 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1


// kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size
// padding is zero

@ -1638,6 +1643,18 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// depthwise
GGML_API struct ggml_tensor * ggml_conv_2d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride dimension 0
int s1, // stride dimension 1
int p0, // padding dimension 0
int p1, // padding dimension 1
int d0, // dilation dimension 0
int d1); // dilation dimension 1
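// note: in a depthwise convolution each input channel is convolved with its
// own kernel (number of groups == number of channels), so channels do not mix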

GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
struct ggml_context * ctx,
struct ggml_tensor * a,

@ -534,7 +534,6 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node);
hn->buffer_id = buffer_id;
hn->offset = offset;
return;
}
}

@ -74,112 +74,77 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
(NOT CMAKE_OSX_ARCHITECTURES AND
NOT CMAKE_GENERATOR_PLATFORM_LWR AND
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))

message(STATUS "ARM detected")

if (MSVC)
list(APPEND ARCH_DEFINITIONS __aarch64__) # MSVC defines _M_ARM64 instead
list(APPEND ARCH_DEFINITIONS __ARM_NEON)
list(APPEND ARCH_DEFINITIONS __ARM_FEATURE_FMA)

set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")

check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
if (GGML_COMPILER_SUPPORT_DOTPROD)
list(APPEND ARCH_DEFINITIONS __ARM_FEATURE_DOTPROD)

message(STATUS "ARM feature DOTPROD enabled")
endif ()

check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)

if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
list(APPEND ARCH_DEFINITIONS __ARM_FEATURE_MATMUL_INT8)

message(STATUS "ARM feature MATMUL_INT8 enabled")
endif ()

check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
list(APPEND ARCH_DEFINITIONS __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

message(STATUS "ARM feature FP16_VECTOR_ARITHMETIC enabled")
endif ()

set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
elseif (APPLE)
if (GGML_NATIVE)
set(USER_PROVIDED_MARCH FALSE)
foreach(flag_var IN ITEMS CMAKE_C_FLAGS CMAKE_CXX_FLAGS CMAKE_REQUIRED_FLAGS)
if ("${${flag_var}}" MATCHES "-march=[a-zA-Z0-9+._-]+")
set(USER_PROVIDED_MARCH TRUE)
break()
endif()
endforeach()

if (NOT USER_PROVIDED_MARCH)
set(MARCH_FLAGS "-march=armv8.2a")

check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
if (GGML_COMPILER_SUPPORT_DOTPROD)
set(MARCH_FLAGS "${MARCH_FLAGS}+dotprod")
list(APPEND ARCH_DEFINITIONS __ARM_FEATURE_DOTPROD)

message(STATUS "ARM feature DOTPROD enabled")
endif ()

set(TEST_I8MM_FLAGS "-march=armv8.2a+i8mm")

set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${TEST_I8MM_FLAGS}")

check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
set(MARCH_FLAGS "${MARCH_FLAGS}+i8mm")
list(APPEND ARCH_DEFINITIONS __ARM_FEATURE_MATMUL_INT8)

message(STATUS "ARM feature MATMUL_INT8 enabled")
endif ()

set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})

list(APPEND ARCH_FLAGS "${MARCH_FLAGS}")
endif ()
endif ()
if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
else()
check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
list(APPEND ARCH_FLAGS -mfp16-format=ieee)
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
# Raspberry Pi 1, Zero
list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
# Android armeabi-v7a
list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
else()
# Raspberry Pi 2
list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)

if (GGML_NATIVE)
list(APPEND ARCH_FLAGS -mcpu=native)

set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})

# -mcpu=native does not always enable all the features in some compilers,
# so we check for them manually and enable them if available

include(CheckCXXSourceRuns)

set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}+dotprod")
check_cxx_source_runs(
"#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }"
GGML_COMPILER_SUPPORT_DOTPROD)
if (GGML_COMPILER_SUPPORT_DOTPROD)
set(ARCH_FLAGS "${ARCH_FLAGS}+dotprod")
endif()

set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}+i8mm")
check_cxx_source_runs(
"#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }"
GGML_COMPILER_SUPPORT_I8MM)
if (GGML_COMPILER_SUPPORT_I8MM)
set(ARCH_FLAGS "${ARCH_FLAGS}+i8mm")
endif()

set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})

else()
if (GGML_CPU_ARM_ARCH)
list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
endif()
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
# Android arm64-v8a
# Raspberry Pi 3, 4, Zero 2 (32-bit)
list(APPEND ARCH_FLAGS -mno-unaligned-access)
endif()
if (GGML_SVE)
list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)

# show enabled features
execute_process(
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
INPUT_FILE "/dev/null"
OUTPUT_VARIABLE ARM_FEATURE
RESULT_VARIABLE ARM_FEATURE_RESULT
)
if (ARM_FEATURE_RESULT)
message(FATAL_ERROR "Failed to get ARM features")
else()
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
if (NOT ${feature_pos} EQUAL -1)
message(STATUS "ARM feature ${feature} enabled")
endif()
endforeach()
endif()
endif()
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$"))

message(STATUS "x86 detected")

if (MSVC)
# instruction set detection for MSVC only
if (GGML_NATIVE)

@ -394,8 +394,11 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
switch (op->op) {
case GGML_OP_CPY:
return
op->type != GGML_TYPE_IQ3_XXS &&
op->type != GGML_TYPE_IQ3_S &&
op->type != GGML_TYPE_IQ2_XXS &&
op->type != GGML_TYPE_IQ2_XS &&
op->type != GGML_TYPE_IQ2_S &&
op->type != GGML_TYPE_IQ1_S &&
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
case GGML_OP_MUL_MAT:

@ -519,6 +522,12 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
if (ggml_cpu_has_sve()) {
features.push_back({ "SVE", "1" });
}
if (ggml_cpu_has_dotprod()) {
features.push_back({ "DOTPROD", "1" });
}
if (ggml_cpu_has_matmul_int8()) {
features.push_back({ "MATMUL_INT8", "1" });
}
if (ggml_cpu_get_sve_cnt() > 0) {
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
features.push_back({ "SVE_CNT", sve_cnt.c_str() });

@ -204,6 +204,7 @@ template <> inline float32x4_t load(const float *p) {
return vld1q_f32(p);
}
#if !defined(_MSC_VER)
// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <> inline float16x8_t load(const ggml_fp16_t *p) {
return vld1q_f16((const float16_t *)p);
}

@ -551,6 +551,22 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)

// expose GGUF internals for test code

GGML_API size_t gguf_type_size(enum gguf_type type);

GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

struct gguf_buf {
void * data;
size_t size;
size_t offset;
};
GGML_API struct gguf_buf gguf_buf_init(size_t size);
GGML_API void gguf_buf_free(struct gguf_buf buf);

GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta);

#ifdef __cplusplus
}
#endif

@ -245,6 +245,7 @@ struct vk_device_struct {
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;

// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];

@ -528,6 +529,13 @@ struct vk_op_pool2d_push_constants {
int32_t p0; int32_t p1;
};

struct vk_op_rwkv_wkv6_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
uint32_t H;
};

// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}

@ -1363,7 +1371,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
// Needs to be kept up to date on shader changes
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const uint32_t warps = warptile[0] / device->subgroup_size;
const uint32_t warps = warptile[0] / warptile[10];

const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;

@ -1377,8 +1385,9 @@ static void ggml_vk_load_shaders(vk_device& device) {

std::cerr << "ggml_vulkan: Compiling shaders";

// some shaders require the subgroup size to be 16 or larger
// some shaders have a minimum subgroup size
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);

// mulmat
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,

@ -1445,7 +1454,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };

l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };

@ -1864,7 +1873,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);

ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);

@ -1878,7 +1887,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);

ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);

@ -1892,7 +1901,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);

// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);

@ -2014,6 +2023,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

for (auto &c : compiles) {
c.wait();
}

@ -5022,6 +5033,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_pool2d_f32;
}
return nullptr;
case GGML_OP_RWKV_WKV6:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;

@ -5424,6 +5440,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}

static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * k = dst->src[0];
const ggml_tensor * v = dst->src[1];
const ggml_tensor * r = dst->src[2];
const ggml_tensor * tf = dst->src[3];
const ggml_tensor * td = dst->src[4];
const ggml_tensor * state = dst->src[5];

GGML_ASSERT(!ggml_is_quantized(k->type));
GGML_ASSERT(!ggml_is_quantized(v->type));
GGML_ASSERT(!ggml_is_quantized(r->type));
GGML_ASSERT(!ggml_is_quantized(tf->type));
GGML_ASSERT(!ggml_is_quantized(td->type));
GGML_ASSERT(!ggml_is_quantized(state->type));
GGML_ASSERT(dst->buffer != nullptr);

vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
GGML_ASSERT(pipeline != nullptr);

if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
return;
}

ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;

ggml_vk_sync_buffers(subctx);

vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);

K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
R_uma = d_R != nullptr;
TF_uma = d_TF != nullptr;
TD_uma = d_TD != nullptr;
STATE_uma = d_State != nullptr;
DST_uma = d_D != nullptr;
}

if (!K_uma) {
d_K = k_buf_ctx->dev_buffer;
k_offset = vk_tensor_offset(k) + k->view_offs;
}
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_offset = vk_tensor_offset(v) + v->view_offs;
}
if (!R_uma) {
d_R = r_buf_ctx->dev_buffer;
r_offset = vk_tensor_offset(r) + r->view_offs;
}
if (!TF_uma) {
d_TF = tf_buf_ctx->dev_buffer;
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
}
if (!TD_uma) {
d_TD = td_buf_ctx->dev_buffer;
td_offset = vk_tensor_offset(td) + td->view_offs;
}
if (!STATE_uma) {
d_State = state_buf_ctx->dev_buffer;
state_offset = vk_tensor_offset(state) + state->view_offs;
}
if (!DST_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}

const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
const uint64_t r_size = ggml_nbytes(r);
const uint64_t tf_size = ggml_nbytes(tf);
const uint64_t td_size = ggml_nbytes(td);
const uint64_t state_size = ggml_nbytes(state);
const uint64_t dst_size = ggml_nbytes(dst);

std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
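// i.e. dispatch one workgroup per (sequence, head) pair: B * H groups in total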

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size },
vk_subbuffer{ d_V, v_offset, v_size },
vk_subbuffer{ d_R, r_offset, r_size },
vk_subbuffer{ d_TF, tf_offset, tf_size },
vk_subbuffer{ d_TD, td_offset, td_size },
vk_subbuffer{ d_State, state_offset, state_size },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
}

static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[3];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[2];
const size_t n_seqs = dst->src[5]->ne[1];
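// these fill vk_op_rwkv_wkv6_push_constants in declaration order:
// B = n_seqs, T = seq_length, C = n_embed, H = n_heads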

ggml_vk_op_f32_rwkv6(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
dryrun
);
}

static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
int * op_params = (int *)dst->op_params;

@ -6569,6 +6713,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
break;

@ -6768,6 +6913,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_FLASH_ATTN_EXT:
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);

break;

case GGML_OP_RWKV_WKV6:
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);

break;
default:
return false;

@ -6848,6 +6998,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
buf = tensor->buffer;

@ -7724,6 +7875,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_LEAKY_RELU:
return true;
default:

@ -8300,7 +8452,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
} else {
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
tensor->src[4], tensor->src[5]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
}

@ -32,7 +32,7 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
void soft_max(uint num_iters) {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;

if (rowx >= p.nrows_x) {
return;

@ -479,6 +479,8 @@ void process_shaders() {

string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

for (auto &c : compiles) {
c.wait();
}

87
ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp
Normal file

@ -0,0 +1,87 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
uint B;
uint T;
uint C;
uint H;
};

layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };

shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];

void main() {
const uint head_size = BLOCK_SIZE;
const uint batch_id = gl_WorkGroupID.x / H;
const uint head_id = gl_WorkGroupID.x % H;
const uint tid = gl_LocalInvocationID.x;

const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
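// each workgroup handles one head of one sequence; each invocation (tid) owns
// one output channel and one column of the head_size x head_size state matrix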

if (batch_id >= B || head_id >= H) {
return;
}

A_TYPE state[BLOCK_SIZE];
[[unroll]] for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}

barrier();
_tf[tid] = tf[head_id * head_size + tid];
barrier();

const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
barrier();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
barrier();
|
||||
|
||||
const A_TYPE v_val = v[t];
|
||||
A_TYPE y = 0.0;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
vec4 kv = k_vec * v_val;
|
||||
|
||||
vec4 temp = tf_vec * kv + s_vec;
|
||||
y += dot(r_vec, temp);
|
||||
|
||||
s_vec = s_vec * td_vec + kv;
|
||||
state[j] = s_vec.x;
|
||||
state[j+1] = s_vec.y;
|
||||
state[j+2] = s_vec.z;
|
||||
state[j+3] = s_vec.w;
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
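For readers who find the shared-memory tiling above hard to follow, here is a minimal scalar C++ sketch of the same WKV6 recurrence that the shader vectorizes. This is an illustrative standalone function, not part of the tree: one GPU thread corresponds to one output channel c, and each channel keeps a row of decayed state over the key dimension.

#include <vector>

// Scalar reference for one head of size N (a sketch for clarity only).
// k, v, r, td are per-token vectors of length N; tf is the per-head
// "time-first" vector; state[c][j] pairs output channel c with key channel j.
void wkv6_step(int N,
               const float * k, const float * v, const float * r,
               const float * tf, const float * td,
               std::vector<std::vector<float>> & state, float * y) {
    for (int c = 0; c < N; ++c) {                      // one shader thread per c
        float acc = 0.0f;
        for (int j = 0; j < N; ++j) {
            const float kv = k[j] * v[c];              // outer-product entry
            acc += r[j] * (tf[j] * kv + state[c][j]);  // time-first contribution
            state[c][j] = state[c][j] * td[j] + kv;    // decayed state update
        }
        y[c] = acc;
    }
}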
295  ggml/src/ggml.c
@@ -3760,104 +3760,10 @@ struct ggml_tensor * ggml_clamp(
    return result;
}

// ggml_conv_1d

static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

GGML_API struct ggml_tensor * ggml_conv_1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   p0,
        int                   d0) {
    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]

    struct ggml_tensor * result =
        ggml_mul_mat(ctx,
                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]

    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]

    return result;
}

// ggml_conv_1d_ph

struct ggml_tensor* ggml_conv_1d_ph(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s,
        int                   d) {
    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
}

// ggml_conv_transpose_1d

static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
}

GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   p0,
        int                   d0) {
    GGML_ASSERT(ggml_is_matrix(b));
    GGML_ASSERT(a->ne[2] == b->ne[1]);
    GGML_ASSERT(a->ne[3] == 1);

    GGML_ASSERT(p0 == 0);
    GGML_ASSERT(d0 == 1);

    const int64_t ne[4] = {
        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
        a->ne[1], b->ne[2], 1,
    };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

    int32_t params[] = { s0, p0, d0 };
    ggml_set_op_params(result, params, sizeof(params));

    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

// ggml_conv_depthwise

struct ggml_tensor * ggml_conv_depthwise_2d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   s1,
        int                   p0,
        int                   p1,
        int                   d0,
        int                   d1) {
    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]

    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]

    return result;
}
// ggml_conv_2d

// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
// a: [OC,IC, KH, KW]
// b: [N, IC, IH, IW]
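As a quick sanity check of the ggml_calc_conv_output_size formula above, the following illustrative snippet (not part of the tree) evaluates it on a few familiar cases:

#include <cassert>
#include <cstdint>

// Same arithmetic as ggml_calc_conv_output_size, checked on known cases.
static int64_t conv_out_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

int main() {
    assert(conv_out_size(10, 3, 1, 0, 1) ==  8); // valid conv: L - K + 1
    assert(conv_out_size(10, 3, 1, 1, 1) == 10); // "same" padding for K = 3
    assert(conv_out_size(10, 3, 2, 0, 1) ==  4); // strided
}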
@@ -3874,10 +3780,11 @@ struct ggml_tensor * ggml_im2col(
        int                   d1,
        bool                  is_2D,
        enum ggml_type        dst_type) {
    if(is_2D) {
    if (is_2D) {
        GGML_ASSERT(a->ne[2] == b->ne[2]);
    } else {
        GGML_ASSERT(a->ne[1] == b->ne[1]);
        //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
        GGML_ASSERT(b->ne[1] == a->ne[1]);
        GGML_ASSERT(b->ne[3] == 1);
    }
@@ -3928,6 +3835,108 @@ struct ggml_tensor * ggml_im2col_back(
    return result;
}

// ggml_conv_1d

struct ggml_tensor * ggml_conv_1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   p0,
        int                   d0) {
    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]

    struct ggml_tensor * result =
        ggml_mul_mat(ctx,
                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]

    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]

    return result;
}

// ggml_conv_1d_ph

struct ggml_tensor* ggml_conv_1d_ph(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s,
        int                   d) {
    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
}

// ggml_conv_1d_dw

struct ggml_tensor * ggml_conv_1d_dw(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   p0,
        int                   d0) {
    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);

    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);

    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);

    result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);

    return result;
}

// ggml_conv_1d_dw_ph

struct ggml_tensor * ggml_conv_1d_dw_ph(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   d0) {
    return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
}

// ggml_conv_transpose_1d

static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
}

GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   p0,
        int                   d0) {
    GGML_ASSERT(ggml_is_matrix(b));
    GGML_ASSERT(a->ne[2] == b->ne[1]);
    GGML_ASSERT(a->ne[3] == 1);

    GGML_ASSERT(p0 == 0);
    GGML_ASSERT(d0 == 1);

    const int64_t ne[4] = {
        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
        a->ne[1], b->ne[2], 1,
    };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

    int32_t params[] = { s0, p0, d0 };
    ggml_set_op_params(result, params, sizeof(params));

    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

// ggml_conv_2d

// a: [OC,IC, KH, KW]
// b: [N, IC, IH, IW]
// result: [N, OC, OH, OW]

@@ -3973,6 +3982,31 @@ struct ggml_tensor * ggml_conv_2d_s1_ph(
    return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
}

// ggml_conv_2d_dw

struct ggml_tensor * ggml_conv_2d_dw(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int                   s0,
        int                   s1,
        int                   p0,
        int                   p1,
        int                   d0,
        int                   d1) {
    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]

    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]

    return result;
}

// ggml_conv_transpose_2d_p0

static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
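A minimal usage sketch for the new depthwise 1-D path follows. The tensor shapes here are assumptions read off the reshape logic above (filters a: [K, C], signal b: [L, C]), and the context setup is the usual ggml boilerplate; treat it as illustrative rather than canonical:

#include "ggml.h"

// Sketch: depthwise conv over 8 channels of length 128, K = 3,
// stride 1, padding 1, dilation 1 (error handling omitted).
int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,   3, 8); // one filter per channel
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 8); // the signal

    struct ggml_tensor * out = ggml_conv_1d_dw(ctx, a, b, 1, 1, 1);
    (void) out; // build into a graph and compute as usual

    ggml_free(ctx);
}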
@@ -6037,12 +6071,12 @@ struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, co

struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL;
}

struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL;
}

void ggml_graph_print(const struct ggml_cgraph * cgraph) {
@@ -6489,7 +6523,7 @@ struct gguf_context {
    void * data;
};

static size_t gguf_type_size(enum gguf_type type) {
size_t gguf_type_size(enum gguf_type type) {
    GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
    return GGUF_TYPE_SIZE[type];
}

@@ -6617,13 +6651,7 @@ struct gguf_context * gguf_init_empty(void) {
    return ctx;
}

struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
    FILE * file = ggml_fopen(fname, "rb");
    if (!file) {
        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
        return NULL;
    }

struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
    // offset from start of file
    size_t offset = 0;

@@ -6636,7 +6664,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
    for (uint32_t i = 0; i < sizeof(magic); i++) {
        if (magic[i] != GGUF_MAGIC[i]) {
            fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
            fclose(file);
            return NULL;
        }
    }

@@ -6647,7 +6674,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
    if (!ctx) {
        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
        fclose(file);
        return NULL;
    }

@@ -6665,7 +6691,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

    if (ctx->header.version == 1) {
        fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
        fclose(file);
        gguf_free(ctx);
        return NULL;
    }

@@ -6678,7 +6703,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

    if (!ok) {
        fprintf(stderr, "%s: failed to read header\n", __func__);
        fclose(file);
        gguf_free(ctx);
        return NULL;
    }

@@ -6688,12 +6712,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
    {
        const uint64_t n_kv = ctx->header.n_kv;

        ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
        if (!ctx->kv) {
            fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
            fclose(file);
            gguf_free(ctx);
            return NULL;
        if (n_kv > 0) {
            ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
            if (!ctx->kv) {
                fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
                gguf_free(ctx);
                return NULL;
            }
        }

        for (uint64_t i = 0; i < n_kv; ++i) {

@@ -6740,7 +6765,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                        // prevent from integer overflow in the malloc below
                        if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
                            fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
                            fclose(file);
                            gguf_free(ctx);
                            return NULL;
                        }

@@ -6748,7 +6772,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                        kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
                        if (!kv->value.arr.data) {
                            fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
                            fclose(file);
                            gguf_free(ctx);
                            return NULL;
                        }

@@ -6760,7 +6783,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                        // prevent from integer overflow in the malloc below
                        if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
                            fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
                            fclose(file);
                            gguf_free(ctx);
                            return NULL;
                        }

@@ -6768,7 +6790,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                        kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
                        if (!kv->value.arr.data) {
                            fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
                            fclose(file);
                            gguf_free(ctx);
                            return NULL;
                        }

@@ -6799,7 +6820,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

        if (!ok) {
            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
            fclose(file);
            gguf_free(ctx);
            return NULL;
        }

@@ -6810,7 +6830,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
        ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
        if (!ctx->infos) {
            fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
            fclose(file);
            gguf_free(ctx);
            return NULL;
        }

@@ -6846,7 +6865,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

            if (!ok) {
                fprintf(stderr, "%s: failed to read tensor info\n", __func__);
                fclose(file);
                gguf_free(ctx);
                return NULL;
            }

@@ -6889,7 +6907,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                    // this tensor type support have been removed:
                    fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
                            __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
                    fclose(file);
                    gguf_free(ctx);
                    return NULL;
                }

@@ -6897,7 +6914,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
            if (ne % ggml_blck_size(info->type) != 0) {
                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
                fclose(file);
                gguf_free(ctx);
                return NULL;
            }

@@ -6929,7 +6945,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
        *params.ctx = ggml_init(pdata);
        if (*params.ctx == NULL) {
            fprintf(stderr, "%s: failed to initialize context\n", __func__);
            fclose(file);
            gguf_free(ctx);
            return NULL;
        }

@@ -6948,7 +6963,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

            if (!ok) {
                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
                fclose(file);
                ggml_free(ctx_data);
                gguf_free(ctx);
                return NULL;

@@ -6987,7 +7001,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

            if (!ok) {
                fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
                fclose(file);
                ggml_free(ctx_data);
                gguf_free(ctx);
                return NULL;

@@ -6996,11 +7009,21 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
        ggml_set_no_alloc(ctx_data, params.no_alloc);
    }

    fclose(file);

    return ctx;
}

struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
    FILE * file = ggml_fopen(fname, "rb");
    if (!file) {
        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
        return NULL;
    }

    struct gguf_context * result = gguf_init_from_file_impl(file, params);
    fclose(file);
    return result;
}

void gguf_free(struct gguf_context * ctx) {
    if (ctx == NULL) {
        return;

@@ -7460,13 +7483,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
//    fwrite(val, sizeof(char), size, file);
//}

struct gguf_buf {
    void * data;
    size_t size;
    size_t offset;
};

static struct gguf_buf gguf_buf_init(size_t size) {
struct gguf_buf gguf_buf_init(size_t size) {
    struct gguf_buf buf = {
        /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
        /*buf.size =*/ size,

@@ -7476,7 +7493,7 @@ static struct gguf_buf gguf_buf_init(size_t size) {
    return buf;
}

static void gguf_buf_free(struct gguf_buf buf) {
void gguf_buf_free(struct gguf_buf buf) {
    if (buf.data) {
        GGML_FREE(buf.data);
    }

@@ -7514,7 +7531,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
    buf->offset += el_size;
}

static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
    // write header
    gguf_bwrite_el(buf, &ctx->header.magic,   sizeof(ctx->header.magic));
    gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
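Splitting the reader into a FILE*-based impl plus a thin path-based wrapper (and having the wrapper own the fclose) makes it straightforward to parse GGUF data that never touches a named file, which is presumably what the new test-gguf.cpp relies on. A sketch of that pattern, assuming the impl symbol is visible to the caller and using POSIX fmemopen:

#include <cstdio>
#include "ggml.h"

// Sketch: parse an in-memory GGUF blob through the new impl entry point.
// fmemopen is POSIX-only; on other platforms, write the blob to a temp file.
struct gguf_context * parse_gguf_blob(void * data, size_t size) {
    FILE * f = fmemopen(data, size, "rb");
    if (!f) {
        return NULL;
    }
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * ctx = gguf_init_from_file_impl(f, params);
    fclose(f); // the caller owns the stream, mirroring the new wrapper
    return ctx;
}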
@@ -90,6 +90,7 @@ class Keys:
        VOCAB_SIZE = "{arch}.vocab_size"
        CONTEXT_LENGTH = "{arch}.context_length"
        EMBEDDING_LENGTH = "{arch}.embedding_length"
        FEATURES_LENGTH = "{arch}.features_length"
        BLOCK_COUNT = "{arch}.block_count"
        LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
        FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"

@@ -122,6 +123,8 @@ class Keys:
        VALUE_LENGTH = "{arch}.attention.value_length"
        LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
        LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
        GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
        GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
        CAUSAL = "{arch}.attention.causal"
        Q_LORA_RANK = "{arch}.attention.q_lora_rank"
        KV_LORA_RANK = "{arch}.attention.kv_lora_rank"

@@ -155,6 +158,14 @@ class Keys:
    class WKV:
        HEAD_SIZE = "{arch}.wkv.head_size"

    class PosNet:
        EMBEDDING_LENGTH = "{arch}.posnet.embedding_length"
        BLOCK_COUNT      = "{arch}.posnet.block_count"

    class ConvNext:
        EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
        BLOCK_COUNT      = "{arch}.convnext.block_count"

    class Tokenizer:
        MODEL = "tokenizer.ggml.model"
        PRE = "tokenizer.ggml.pre"

@@ -209,58 +220,59 @@ class GGUFType:


class MODEL_ARCH(IntEnum):
    LLAMA = auto()
    FALCON = auto()
    BAICHUAN = auto()
    GROK = auto()
    GPT2 = auto()
    GPTJ = auto()
    GPTNEOX = auto()
    MPT = auto()
    STARCODER = auto()
    REFACT = auto()
    BERT = auto()
    NOMIC_BERT = auto()
    JINA_BERT_V2 = auto()
    BLOOM = auto()
    STABLELM = auto()
    QWEN = auto()
    QWEN2 = auto()
    QWEN2MOE = auto()
    QWEN2VL = auto()
    PHI2 = auto()
    PHI3 = auto()
    PLAMO = auto()
    CODESHELL = auto()
    ORION = auto()
    INTERNLM2 = auto()
    MINICPM = auto()
    MINICPM3 = auto()
    GEMMA = auto()
    GEMMA2 = auto()
    STARCODER2 = auto()
    RWKV6 = auto()
    MAMBA = auto()
    XVERSE = auto()
    COMMAND_R = auto()
    DBRX = auto()
    OLMO = auto()
    OLMO2 = auto()
    OLMOE = auto()
    OPENELM = auto()
    ARCTIC = auto()
    DEEPSEEK = auto()
    DEEPSEEK2 = auto()
    CHATGLM = auto()
    BITNET = auto()
    T5 = auto()
    T5ENCODER = auto()
    JAIS = auto()
    NEMOTRON = auto()
    EXAONE = auto()
    GRANITE = auto()
    GRANITE_MOE = auto()
    CHAMELEON = auto()
    LLAMA            = auto()
    FALCON           = auto()
    BAICHUAN         = auto()
    GROK             = auto()
    GPT2             = auto()
    GPTJ             = auto()
    GPTNEOX          = auto()
    MPT              = auto()
    STARCODER        = auto()
    REFACT           = auto()
    BERT             = auto()
    NOMIC_BERT       = auto()
    JINA_BERT_V2     = auto()
    BLOOM            = auto()
    STABLELM         = auto()
    QWEN             = auto()
    QWEN2            = auto()
    QWEN2MOE         = auto()
    QWEN2VL          = auto()
    PHI2             = auto()
    PHI3             = auto()
    PLAMO            = auto()
    CODESHELL        = auto()
    ORION            = auto()
    INTERNLM2        = auto()
    MINICPM          = auto()
    MINICPM3         = auto()
    GEMMA            = auto()
    GEMMA2           = auto()
    STARCODER2       = auto()
    RWKV6            = auto()
    MAMBA            = auto()
    XVERSE           = auto()
    COMMAND_R        = auto()
    DBRX             = auto()
    OLMO             = auto()
    OLMO2            = auto()
    OLMOE            = auto()
    OPENELM          = auto()
    ARCTIC           = auto()
    DEEPSEEK         = auto()
    DEEPSEEK2        = auto()
    CHATGLM          = auto()
    BITNET           = auto()
    T5               = auto()
    T5ENCODER        = auto()
    JAIS             = auto()
    NEMOTRON         = auto()
    EXAONE           = auto()
    GRANITE          = auto()
    GRANITE_MOE      = auto()
    CHAMELEON        = auto()
    WAVTOKENIZER_DEC = auto()


class MODEL_TENSOR(IntEnum):

@@ -370,61 +382,78 @@ class MODEL_TENSOR(IntEnum):
    ENC_OUTPUT_NORM = auto()
    CLS = auto() # classifier
    CLS_OUT = auto() # classifier output projection
    CONV1D = auto()
    CONVNEXT_DW = auto()
    CONVNEXT_NORM = auto()
    CONVNEXT_PW1 = auto()
    CONVNEXT_PW2 = auto()
    CONVNEXT_GAMMA = auto()
    POSNET_CONV1 = auto()
    POSNET_CONV2 = auto()
    POSNET_NORM = auto()
    POSNET_NORM1 = auto()
    POSNET_NORM2 = auto()
    POSNET_ATTN_NORM = auto()
    POSNET_ATTN_Q = auto()
    POSNET_ATTN_K = auto()
    POSNET_ATTN_V = auto()
    POSNET_ATTN_OUT = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.LLAMA: "llama",
    MODEL_ARCH.FALCON: "falcon",
    MODEL_ARCH.BAICHUAN: "baichuan",
    MODEL_ARCH.GROK: "grok",
    MODEL_ARCH.GPT2: "gpt2",
    MODEL_ARCH.GPTJ: "gptj",
    MODEL_ARCH.GPTNEOX: "gptneox",
    MODEL_ARCH.MPT: "mpt",
    MODEL_ARCH.STARCODER: "starcoder",
    MODEL_ARCH.REFACT: "refact",
    MODEL_ARCH.BERT: "bert",
    MODEL_ARCH.NOMIC_BERT: "nomic-bert",
    MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
    MODEL_ARCH.BLOOM: "bloom",
    MODEL_ARCH.STABLELM: "stablelm",
    MODEL_ARCH.QWEN: "qwen",
    MODEL_ARCH.QWEN2: "qwen2",
    MODEL_ARCH.QWEN2MOE: "qwen2moe",
    MODEL_ARCH.QWEN2VL: "qwen2vl",
    MODEL_ARCH.PHI2: "phi2",
    MODEL_ARCH.PHI3: "phi3",
    MODEL_ARCH.PLAMO: "plamo",
    MODEL_ARCH.CODESHELL: "codeshell",
    MODEL_ARCH.ORION: "orion",
    MODEL_ARCH.INTERNLM2: "internlm2",
    MODEL_ARCH.MINICPM: "minicpm",
    MODEL_ARCH.MINICPM3: "minicpm3",
    MODEL_ARCH.GEMMA: "gemma",
    MODEL_ARCH.GEMMA2: "gemma2",
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.RWKV6: "rwkv6",
    MODEL_ARCH.MAMBA: "mamba",
    MODEL_ARCH.XVERSE: "xverse",
    MODEL_ARCH.COMMAND_R: "command-r",
    MODEL_ARCH.DBRX: "dbrx",
    MODEL_ARCH.OLMO: "olmo",
    MODEL_ARCH.OLMO2: "olmo2",
    MODEL_ARCH.OLMOE: "olmoe",
    MODEL_ARCH.OPENELM: "openelm",
    MODEL_ARCH.ARCTIC: "arctic",
    MODEL_ARCH.DEEPSEEK: "deepseek",
    MODEL_ARCH.DEEPSEEK2: "deepseek2",
    MODEL_ARCH.CHATGLM: "chatglm",
    MODEL_ARCH.BITNET: "bitnet",
    MODEL_ARCH.T5: "t5",
    MODEL_ARCH.T5ENCODER: "t5encoder",
    MODEL_ARCH.JAIS: "jais",
    MODEL_ARCH.NEMOTRON: "nemotron",
    MODEL_ARCH.EXAONE: "exaone",
    MODEL_ARCH.GRANITE: "granite",
    MODEL_ARCH.GRANITE_MOE: "granitemoe",
    MODEL_ARCH.CHAMELEON: "chameleon",
    MODEL_ARCH.LLAMA:            "llama",
    MODEL_ARCH.FALCON:           "falcon",
    MODEL_ARCH.BAICHUAN:         "baichuan",
    MODEL_ARCH.GROK:             "grok",
    MODEL_ARCH.GPT2:             "gpt2",
    MODEL_ARCH.GPTJ:             "gptj",
    MODEL_ARCH.GPTNEOX:          "gptneox",
    MODEL_ARCH.MPT:              "mpt",
    MODEL_ARCH.STARCODER:        "starcoder",
    MODEL_ARCH.REFACT:           "refact",
    MODEL_ARCH.BERT:             "bert",
    MODEL_ARCH.NOMIC_BERT:       "nomic-bert",
    MODEL_ARCH.JINA_BERT_V2:     "jina-bert-v2",
    MODEL_ARCH.BLOOM:            "bloom",
    MODEL_ARCH.STABLELM:         "stablelm",
    MODEL_ARCH.QWEN:             "qwen",
    MODEL_ARCH.QWEN2:            "qwen2",
    MODEL_ARCH.QWEN2MOE:         "qwen2moe",
    MODEL_ARCH.QWEN2VL:          "qwen2vl",
    MODEL_ARCH.PHI2:             "phi2",
    MODEL_ARCH.PHI3:             "phi3",
    MODEL_ARCH.PLAMO:            "plamo",
    MODEL_ARCH.CODESHELL:        "codeshell",
    MODEL_ARCH.ORION:            "orion",
    MODEL_ARCH.INTERNLM2:        "internlm2",
    MODEL_ARCH.MINICPM:          "minicpm",
    MODEL_ARCH.MINICPM3:         "minicpm3",
    MODEL_ARCH.GEMMA:            "gemma",
    MODEL_ARCH.GEMMA2:           "gemma2",
    MODEL_ARCH.STARCODER2:       "starcoder2",
    MODEL_ARCH.RWKV6:            "rwkv6",
    MODEL_ARCH.MAMBA:            "mamba",
    MODEL_ARCH.XVERSE:           "xverse",
    MODEL_ARCH.COMMAND_R:        "command-r",
    MODEL_ARCH.DBRX:             "dbrx",
    MODEL_ARCH.OLMO:             "olmo",
    MODEL_ARCH.OLMO2:            "olmo2",
    MODEL_ARCH.OLMOE:            "olmoe",
    MODEL_ARCH.OPENELM:          "openelm",
    MODEL_ARCH.ARCTIC:           "arctic",
    MODEL_ARCH.DEEPSEEK:         "deepseek",
    MODEL_ARCH.DEEPSEEK2:        "deepseek2",
    MODEL_ARCH.CHATGLM:          "chatglm",
    MODEL_ARCH.BITNET:           "bitnet",
    MODEL_ARCH.T5:               "t5",
    MODEL_ARCH.T5ENCODER:        "t5encoder",
    MODEL_ARCH.JAIS:             "jais",
    MODEL_ARCH.NEMOTRON:         "nemotron",
    MODEL_ARCH.EXAONE:           "exaone",
    MODEL_ARCH.GRANITE:          "granite",
    MODEL_ARCH.GRANITE_MOE:      "granitemoe",
    MODEL_ARCH.CHAMELEON:        "chameleon",
    MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {

@@ -534,6 +563,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
    MODEL_TENSOR.CLS: "cls",
    MODEL_TENSOR.CLS_OUT: "cls.output",
    MODEL_TENSOR.CONV1D: "conv1d",
    MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
    MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm",
    MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1",
    MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2",
    MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma",
    MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1",
    MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2",
    MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm",
    MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1",
    MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2",
    MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm",
    MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q",
    MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
    MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
    MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -1372,6 +1417,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.WAVTOKENIZER_DEC: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.TOKEN_EMBD_NORM,
        MODEL_TENSOR.CONV1D,
        MODEL_TENSOR.CONVNEXT_DW,
        MODEL_TENSOR.CONVNEXT_NORM,
        MODEL_TENSOR.CONVNEXT_PW1,
        MODEL_TENSOR.CONVNEXT_PW2,
        MODEL_TENSOR.CONVNEXT_GAMMA,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.POSNET_CONV1,
        MODEL_TENSOR.POSNET_CONV2,
        MODEL_TENSOR.POSNET_NORM,
        MODEL_TENSOR.POSNET_NORM1,
        MODEL_TENSOR.POSNET_NORM2,
        MODEL_TENSOR.POSNET_ATTN_NORM,
        MODEL_TENSOR.POSNET_ATTN_Q,
        MODEL_TENSOR.POSNET_ATTN_K,
        MODEL_TENSOR.POSNET_ATTN_V,
        MODEL_TENSOR.POSNET_ATTN_OUT,
    ],
    # TODO
}
@@ -631,6 +631,21 @@ class GGUFWriter:
    def add_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_features_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)

    def add_posnet_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_posnet_block_count(self, length: int) -> None:
        self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)

    def add_convnext_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_convnext_block_count(self, length: int) -> None:
        self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)

    def add_block_count(self, length: int) -> None:
        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

@@ -727,6 +742,12 @@ class GGUFWriter:
    def add_layer_norm_rms_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

    def add_group_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)

    def add_group_norm_groups(self, value: int) -> None:
        self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)

    def add_causal_attention(self, value: bool) -> None:
        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
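The same metadata can be produced from the C side with the gguf API. In this sketch the key strings are assembled from the {arch} templates above with the new "wavtokenizer-dec" architecture name, and the numeric values are purely illustrative:

#include "ggml.h"

// Sketch: emit the new PosNet/ConvNext keys for a wavtokenizer-dec model.
// Key strings mirror gguf-py's Keys.* constants; values here are made up.
void write_wavtokenizer_meta(struct gguf_context * ctx) {
    gguf_set_val_u32(ctx, "wavtokenizer-dec.posnet.embedding_length",   768);
    gguf_set_val_u32(ctx, "wavtokenizer-dec.posnet.block_count",          6);
    gguf_set_val_u32(ctx, "wavtokenizer-dec.convnext.embedding_length", 768);
    gguf_set_val_u32(ctx, "wavtokenizer-dec.convnext.block_count",       12);
    gguf_set_val_f32(ctx, "wavtokenizer-dec.attention.group_norm_epsilon", 1e-6f);
}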
@@ -42,6 +42,7 @@ class TensorNameMap:
            "emb_ln",                    # nomic-bert
            "transformer.norm",          # openelm
            "rwkv.blocks.0.pre_ln",      # rwkv
            "backbone.norm",             # wavtokenizer
        ),

        # Position embeddings

@@ -60,6 +61,7 @@ class TensorNameMap:
            "lm_head.linear",            # phi2
            "output_layer",              # chatglm
            "head",                      # rwkv
            "head.out",                  # wavtokenizer
        ),

        # Output norm

@@ -80,6 +82,7 @@ class TensorNameMap:
            "transformer.norm",          # openelm
            "model.norm",                # nemotron
            "rwkv.ln_out",               # rwkv
            "backbone.final_layer_norm", # wavtokenizer
        ),

        # Rope frequencies

@@ -90,6 +93,10 @@ class TensorNameMap:

        MODEL_TENSOR.ROPE_FACTORS_LONG: (),
        MODEL_TENSOR.ROPE_FACTORS_SHORT: (),

        MODEL_TENSOR.CONV1D: (
            "backbone.embed", # roberta
        ),
    }

    block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {

@@ -681,6 +688,8 @@ class TensorNameMap:
            "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
        ),

        ############################################################################
        # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
        MODEL_TENSOR.ENC_OUTPUT_NORM: (
            "encoder.final_layer_norm", # t5
        ),

@@ -693,6 +702,67 @@ class TensorNameMap:
        MODEL_TENSOR.CLS_OUT: (
            "classifier.out_proj", # roberta
        ),
        #############################################################################

        MODEL_TENSOR.CONVNEXT_DW: (
            "backbone.convnext.{bid}.dwconv", # wavtokenizer
        ),

        MODEL_TENSOR.CONVNEXT_NORM: (
            "backbone.convnext.{bid}.norm", # wavtokenizer
        ),

        MODEL_TENSOR.CONVNEXT_PW1: (
            "backbone.convnext.{bid}.pwconv1", # wavtokenizer
        ),

        MODEL_TENSOR.CONVNEXT_PW2: (
            "backbone.convnext.{bid}.pwconv2", # wavtokenizer
        ),

        MODEL_TENSOR.CONVNEXT_GAMMA: (
            "backbone.convnext.{bid}.gamma", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_CONV1: (
            "backbone.posnet.{bid}.conv1", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_CONV2: (
            "backbone.posnet.{bid}.conv2", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_NORM: (
            "backbone.posnet.{bid}.norm", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_NORM1: (
            "backbone.posnet.{bid}.norm1", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_NORM2: (
            "backbone.posnet.{bid}.norm2", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_ATTN_NORM: (
            "backbone.posnet.{bid}.norm", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_ATTN_Q: (
            "backbone.posnet.{bid}.q", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_ATTN_K: (
            "backbone.posnet.{bid}.k", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_ATTN_V: (
            "backbone.posnet.{bid}.v", # wavtokenizer
        ),

        MODEL_TENSOR.POSNET_ATTN_OUT: (
            "backbone.posnet.{bid}.proj_out", # wavtokenizer
        ),
    }

    # architecture-specific block mappings
@@ -136,7 +136,7 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
        logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")

    sum_diff_bits = np.sum(diff_bits)
    logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)")
    logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
    return False
@ -482,9 +482,6 @@ extern "C" {
|
|||
// Returns the total number of parameters in the model
|
||||
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||
|
||||
// Get a llama model tensor
|
||||
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
|
||||
|
||||
// Returns true if the model contains an encoder that requires llama_encode() call
|
||||
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
|
||||
|
||||
|
@ -1139,16 +1136,12 @@ extern "C" {
|
|||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
||||
int32_t n_vocab, // llama_n_vocab()
|
||||
llama_token special_eos_id, // llama_token_eos()
|
||||
llama_token linefeed_id, // llama_token_nl()
|
||||
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat, // 1.0 = disabled
|
||||
float penalty_freq, // 0.0 = disabled
|
||||
float penalty_present, // 0.0 = disabled
|
||||
bool penalize_nl, // consider newlines as a repeatable token
|
||||
bool ignore_eos); // ignore the end-of-sequence token
|
||||
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat, // 1.0 = disabled
|
||||
float penalty_freq, // 0.0 = disabled
|
||||
float penalty_present); // 0.0 = disabled
|
||||
|
||||
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
||||
|
|
|
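With the EOS/newline special-casing gone from this sampler, construction needs only the four penalty parameters. A typical chain now looks roughly like the sketch below (the parameter values are illustrative, not prescribed by the header):

#include "llama.h"

// Sketch: build a sampler chain around the slimmed-down penalties sampler.
struct llama_sampler * make_chain(void) {
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_penalties(
            /* penalty_last_n  = */ 64,
            /* penalty_repeat  = */ 1.1f,
            /* penalty_freq    = */ 0.0f,
            /* penalty_present = */ 0.0f));
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());
    return chain;
}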
@@ -1 +1 @@
74d66b63eaf207a24f3e93bb922aba131cbf2906
e6d93f40dffe8733d5d72f1d8fa6b3ca27ae899f
@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
// penalties

struct llama_sampler_penalties {
    const int32_t     n_vocab;
    const llama_token special_eos_id;
    const llama_token linefeed_id;

    const int32_t penalty_last_n;
    const float   penalty_repeat;
    const float   penalty_freq;
    const float   penalty_present;

    const bool penalize_nl;
    const bool ignore_eos;

    ring_buffer<llama_token> prev;

    // a frequency map to count token occurrences
    std::unordered_map<llama_token, int> token_count;
};

static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {

@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
        return;
    }

    ctx->token_count[token]++;

    // if the ring buffer is full, remove the oldest token
    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
        const auto old = ctx->prev.front();

        ctx->token_count[old]--;
        if (ctx->token_count[old] == 0) {
            ctx->token_count.erase(old);
        }
    }

    ctx->prev.push_back(token);

#if 0
    // sanity check
    std::unordered_map<llama_token, int> tmp;
    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
        tmp[ctx->prev.rat(i)]++;
    }

    assert(ctx->token_count == tmp);
#endif
}

static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;

    if (ctx->ignore_eos) {
        assert(ctx->special_eos_id >= 0);

        // optimistically check if the candidates are not yet sorted/shuffled/truncated
        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
        } else {
            // else, search for the special EOS token
            for (size_t i = 0; i < cur_p->size; ++i) {
                if (cur_p->data[i].id == ctx->special_eos_id) {
                    cur_p->data[i].logit = -INFINITY;
                    break;
                }
            }
        }
    }

    if ((ctx->penalty_last_n == 0) ||
        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
        return;
    }

    bool nl_found = false;
    size_t nl_idx = 0;
    float nl_logit = -INFINITY;
    if (!ctx->penalize_nl) {
        assert(ctx->linefeed_id >= 0);

        // optimistically check if the candidates are not yet sorted/shuffled/truncated
        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
            nl_found = true;
            nl_idx = ctx->linefeed_id;
            nl_logit = cur_p->data[ctx->linefeed_id].logit;
        } else {
            // else, search for the linefeed token
            for (size_t i = 0; i < cur_p->size; ++i) {
                if (cur_p->data[i].id == ctx->linefeed_id) {
                    nl_found = true;
                    nl_idx = i;
                    nl_logit = cur_p->data[i].logit;
                    break;
                }
            }
        }
    }

    // Create a frequency map to count occurrences of each token in last_tokens
    // TODO: optimize this by maintaining the token count in the sampler context
    using llama_token_cnt = std::unordered_map<llama_token, int>;
    llama_token_cnt token_count;

    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
        token_count[ctx->prev.rat(i)]++;
    }

    // Apply frequency and presence penalties to the cur_p
    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto token_iter = token_count.find(cur_p->data[i].id);
        if (token_iter == token_count.end()) {
        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
        if (token_iter == ctx->token_count.end()) {
            continue;
        }

        const int count = token_iter->second;

        assert(count > 0 && count <= ctx->penalty_last_n);

        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
        if (cur_p->data[i].logit <= 0) {

@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
    }

    cur_p->sorted = false;

    if (!ctx->penalize_nl && nl_found) {
        // restore the logit of the newline token if it was penalized
        cur_p->data[nl_idx].logit = nl_logit;
    }
}

static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
    ctx->prev.clear();
    ctx->token_count.clear();
}

static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
    auto * result = llama_sampler_init_penalties(
            ctx->n_vocab,
            ctx->special_eos_id,
            ctx->linefeed_id,
            ctx->penalty_last_n,
            ctx->penalty_repeat,
            ctx->penalty_freq,
            ctx->penalty_present,
            ctx->penalize_nl,
            ctx->ignore_eos);
            ctx->penalty_present);

    // copy the state
    {

@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
};

struct llama_sampler * llama_sampler_init_penalties(
        int32_t     n_vocab,
        llama_token special_eos_id,
        llama_token linefeed_id,
        int32_t     penalty_last_n,
        float       penalty_repeat,
        float       penalty_freq,
        float       penalty_present,
        bool        penalize_nl,
        bool        ignore_eos) {
    if (linefeed_id == LLAMA_TOKEN_NULL) {
        penalize_nl = true;
    }

    if (special_eos_id == LLAMA_TOKEN_NULL) {
        ignore_eos = false;
    }

        float       penalty_present) {
    penalty_last_n = std::max(penalty_last_n, 0);

    return new llama_sampler {
        /* .iface = */ &llama_sampler_penalties_i,
        /* .ctx   = */ new llama_sampler_penalties {
            /* .n_vocab         = */ n_vocab,
            /* .special_eos_id  = */ special_eos_id,
            /* .linefeed_id     = */ linefeed_id,
            /* .penalty_last_n  = */ penalty_last_n,
            /* .penalty_repeat  = */ penalty_repeat,
            /* .penalty_freq    = */ penalty_freq,
            /* .penalty_present = */ penalty_present,
            /* .penalize_nl     = */ penalize_nl,
            /* .ignore_eos      = */ ignore_eos,
            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
            /* .token_count     = */ {},
        },
    };
}

@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
        if (word.find(str) != std::string::npos) {
            token_sequences.emplace(token_id, std::vector<llama_token>());
        } else {
            size_t word_len = word.size(), str_len = str.size();
            size_t word_len = word.size();
            size_t str_len  = str.size();
            size_t pos = -1;
            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                bool match = true;
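The heart of this refactor is that accept() now maintains the last-n frequency map incrementally instead of rebuilding it on every apply(), turning the per-apply cost from O(last_n) rebuilds into O(1) bookkeeping per accepted token. Stripped of sampler plumbing, the idea is just a sliding-window counter; here is a standalone sketch (std::deque stands in for the ring_buffer used in the real code):

#include <cstddef>
#include <deque>
#include <unordered_map>

// Standalone sketch of the sliding-window token counter that the
// penalties sampler now keeps up to date in accept().
struct window_counter {
    size_t last_n;
    std::deque<int> prev;
    std::unordered_map<int, int> count;

    void accept(int token) {
        count[token]++;
        if (prev.size() >= last_n) {    // window full: evict the oldest token
            const int old = prev.front();
            prev.pop_front();
            if (--count[old] == 0) {
                count.erase(old);       // keep the map small
            }
        }
        prev.push_back(token);
    }
};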
@@ -738,7 +738,7 @@ struct llm_tokenizer_wpm_session {
    std::vector<std::string> words(1, "");

    for (const uint32_t cpt : cpts_nfd) {
        const auto flags = unicode_cpt_flags(cpt);
        const auto flags = unicode_cpt_flags_from_cpt(cpt);

        if (flags.is_whitespace) {
            if (words.back().size()) { // finish previous word if any

@@ -1867,6 +1867,10 @@ int32_t llama_detokenize_impl(
        int32_t text_len_max,
        bool remove_special,
        bool unparse_special) {
    if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
        return 0;
    }

    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");

    int32_t avail = text_len_max;

977  src/llama.cpp (diff suppressed because it is too large)
102  src/unicode.cpp
@ -71,15 +71,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
|
|||
throw std::invalid_argument("failed to convert utf8 to codepoint");
|
||||
}
|
||||
|
||||
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
|
||||
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
|
||||
// std::vector<uint16_t> result;
|
||||
// if (/* 0x0000 <= cp && */ cp <= 0xffff) {
|
||||
// result.emplace_back(cp);
|
||||
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
|
||||
// result.emplace_back(cpt);
|
||||
// return result;
|
||||
// }
|
||||
// if (0x10000 <= cp && cp <= 0x10ffff) {
|
||||
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
|
||||
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
|
||||
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
|
||||
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
|
||||
// return result;
|
||||
// }
|
||||
// throw std::invalid_argument("failed to convert codepoint to utf16");
|
||||
|
@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
|
|||
// return result;
|
||||
//}
|
||||
|
||||
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
||||
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
|
||||
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
|
||||
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
|
||||
|
||||
assert (unicode_ranges_flags.begin()[0].first == 0);
|
||||
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
|
||||
|
@ -253,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
|
@ -371,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
|
@ -572,29 +572,29 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
|||
// interface
|
||||
//
|
||||
|
||||
std::string unicode_cpt_to_utf8(uint32_t cp) {
|
||||
std::string unicode_cpt_to_utf8(uint32_t cpt) {
|
||||
std::string result;
|
||||
|
||||
if (/* 0x00 <= cp && */ cp <= 0x7f) {
|
||||
result.push_back(cp);
|
||||
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
|
||||
result.push_back(cpt);
|
||||
return result;
|
||||
}
|
||||
if (0x80 <= cp && cp <= 0x7ff) {
|
||||
result.push_back(0xc0 | ((cp >> 6) & 0x1f));
|
||||
result.push_back(0x80 | (cp & 0x3f));
|
||||
if (0x80 <= cpt && cpt <= 0x7ff) {
|
||||
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
if (0x800 <= cp && cp <= 0xffff) {
|
||||
result.push_back(0xe0 | ((cp >> 12) & 0x0f));
|
||||
result.push_back(0x80 | ((cp >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cp & 0x3f));
|
||||
if (0x800 <= cpt && cpt <= 0xffff) {
|
||||
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
|
||||
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
if (0x10000 <= cp && cp <= 0x10ffff) {
|
||||
result.push_back(0xf0 | ((cp >> 18) & 0x07));
|
||||
result.push_back(0x80 | ((cp >> 12) & 0x3f));
|
||||
result.push_back(0x80 | ((cp >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cp & 0x3f));
|
||||
if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
|
||||
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
|
||||
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||
result.push_back(0x80 | (cpt & 0x3f));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -624,19 +624,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
|||
return result;
|
||||
}
|
||||
|
||||
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
|
||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
|
||||
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
|
||||
static const auto cpt_flags = unicode_cpt_flags_array();
|
||||
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
|
||||
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
|
||||
}
|
||||
|
||||
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
|
||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
|
||||
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
|
||||
if (utf8.empty()) {
|
||||
return undef; // undefined
|
||||
}
|
||||
size_t offset = 0;
|
||||
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
|
||||
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
|
||||
}
|
||||
|
||||
std::string unicode_byte_to_utf8(uint8_t byte) {
|
||||
|
@ -649,41 +649,41 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
|
|||
return map.at(utf8);
|
||||
}
|
||||
|
||||
uint32_t unicode_tolower(uint32_t cp) {
|
||||
uint32_t unicode_tolower(uint32_t cpt) {
|
||||
// binary search
|
||||
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp,
|
||||
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
|
||||
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
|
||||
return pair.first < value;
|
||||
});
|
||||
if (it != unicode_map_lowercase.end() && it->first == cp) {
|
||||
if (it != unicode_map_lowercase.end() && it->first == cpt) {
|
||||
return it->second;
|
||||
}
|
||||
return cp; // Return the original code point if no lowercase mapping is found
|
||||
return cpt; // Return the original code point if no lowercase mapping is found
|
||||
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", codepoint_flags::NUMBER },
|
||||
{ "\\p{L}", codepoint_flags::LETTER },
|
||||
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
||||
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||
{ "\\p{L}", unicode_cpt_flags::LETTER },
|
||||
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ codepoint_flags::NUMBER, 0xD1 },
|
||||
{ codepoint_flags::LETTER, 0xD2 },
|
||||
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
||||
{ unicode_cpt_flags::NUMBER, 0xD1 },
|
||||
{ unicode_cpt_flags::LETTER, 0xD2 },
|
||||
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
bool need_collapse = false;
|
||||
for (auto & regex_expr : regex_exprs) {
|
||||
for (const auto & regex_expr : regex_exprs) {
|
||||
// search for unicode categories
|
||||
for (const auto & ucat : k_ucat_enum) {
|
||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||
|
@ -709,7 +709,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                continue;
            }

            const auto flags = unicode_cpt_flags(cpts[i]);
            const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);

            if (flags.is_whitespace) {
                //NOTE: C++ std::regex \s does not match 0x85; Rust and Python regex do.
@ -725,7 +725,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
    std::vector<size_t> bpe_offsets = { cpts.size() };

    for (auto & regex_expr : regex_exprs) {
    for (const auto & regex_expr : regex_exprs) {
        // first, see if we have an efficient custom regex implementation
        auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
@ -739,7 +739,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
        // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
        // with the corresponding collapsed representation
        bool use_collapsed = false;
        for (auto & ucat : k_ucat_enum) {
        for (const auto & ucat : k_ucat_enum) {
            if (std::string::npos != regex_expr.find(ucat.first)) {
                use_collapsed = true;
                break;
@ -805,7 +805,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
            // std::wregex \s does not match non-ASCII whitespace, using 0x0B as fallback
            std::wstring wtext(cpts.begin(), cpts.end());
            for (size_t i = 0; i < wtext.size(); ++i) {
                if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
                if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
                    wtext[i] = 0x0B;
                }
            }
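The 0x0B rewrite above is the fallback for the std::wregex path, since \s there does not match non-ASCII whitespace. A self-contained sketch of the same substitution, where a hard-coded U+00A0 check stands in for the full is_whitespace flag lookup:

// Sketch of the non-ASCII whitespace fallback: rewrite such codepoints to
// 0x0B (vertical tab), which std::wregex's \s does match.
#include <iostream>
#include <regex>
#include <string>

int main() {
    // U+00A0 (no-break space) is whitespace but is not matched by \s in std::wregex
    std::wstring wtext = L"foo\u00A0bar";

    for (auto & wc : wtext) {
        if (wc > 0x7F && wc == 0x00A0) { // stand-in for the full is_whitespace check
            wc = 0x0B;
        }
    }

    std::wregex ws(L"\\s+");
    std::cout << std::boolalpha << std::regex_search(wtext, ws) << "\n"; // true
    return 0;
}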
@ -4,9 +4,7 @@
#include <string>
#include <vector>

// TODO: prefix all symbols with "llama_"

struct codepoint_flags {
struct unicode_cpt_flags {
    enum {
        UNDEFINED = 0x0001,
        NUMBER    = 0x0002, // regex: \p{N}
@ -35,7 +33,7 @@ struct codepoint_flags {
    uint16_t is_nfd : 1;

    // decode from uint16
    inline codepoint_flags(const uint16_t flags=0) {
    inline unicode_cpt_flags(const uint16_t flags = 0) {
        *reinterpret_cast<uint16_t*>(this) = flags;
    }
@ -50,18 +48,19 @@ struct codepoint_flags {
size_t unicode_len_utf8(char src);

std::string unicode_cpt_to_utf8(uint32_t cp);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);

std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);

std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);

codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);

std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint8_t unicode_utf8_to_byte(const std::string & utf8);

uint32_t unicode_tolower(uint32_t cp);
uint32_t unicode_tolower(uint32_t cpt);

std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
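The declarations above split the formerly overloaded unicode_cpt_flags() into unicode_cpt_flags_from_cpt() and unicode_cpt_flags_from_utf8(), so the input type is explicit at the call site. A sketch of how hypothetical caller code migrates, assuming the struct also exposes is_letter, is_number, and is_punctuation bitfields alongside the is_nfd and is_whitespace members seen above:

// Hypothetical caller updated for the rename (illustrative only).
#include <cstdint>
#include <string>

#include "unicode.h"

// was: unicode_cpt_flags(cpt); the new name states the input is a decoded codepoint
static bool is_letter_or_number(uint32_t cpt) {
    const auto flags = unicode_cpt_flags_from_cpt(cpt);
    return flags.is_letter || flags.is_number;
}

// was: unicode_cpt_flags(utf8); the new name states the input is UTF-8 text
static bool starts_with_punct(const std::string & utf8) {
    return unicode_cpt_flags_from_utf8(utf8).is_punctuation;
}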
@ -129,6 +129,7 @@ llama_target_and_test(test-arg-parser.cpp)
llama_target_and_test(test-chat-template.cpp)

# llama_target_and_test(test-opt.cpp) # SLOW
llama_target_and_test(test-gguf.cpp)
llama_target_and_test(test-backend-ops.cpp)

llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
@ -3549,8 +3549,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
        for (ggml_type type_dst : all_types) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
        }
    }
    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
1303
tests/test-gguf.cpp
Normal file
File diff suppressed because it is too large
@ -145,7 +145,7 @@ static void test_penalties(
    sampler_tester tester(probs, probs_expected);

    const size_t n_vocab = probs.size();
    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
    auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);

    for (size_t i = 0; i < last_tokens.size(); i++) {
        llama_sampler_accept(sampler, last_tokens[i]);
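The new llama_sampler_init_penalties() drops the vocab-size and special-token arguments of the old signature, leaving only the four penalty parameters. A hedged sketch of standalone usage of the new API; the parameter values are illustrative, and llama_sampler_apply()/llama_sampler_free() are the usual companion calls assumed here:

#include "llama.h"

// Illustrative values: penalize the last 64 tokens with a mild repeat penalty
// and no frequency/presence penalties.
static void example_penalties(const llama_token * recent, int n_recent) {
    llama_sampler * smpl = llama_sampler_init_penalties(
        /*penalty_last_n  =*/ 64,
        /*penalty_repeat  =*/ 1.1f,
        /*penalty_freq    =*/ 0.0f,
        /*penalty_present =*/ 0.0f);

    // feed recent context so the sampler knows which tokens to penalize
    for (int i = 0; i < n_recent; i++) {
        llama_sampler_accept(smpl, recent[i]);
    }

    // ... llama_sampler_apply(smpl, &cur_p) would then adjust candidate logits ...

    llama_sampler_free(smpl);
}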