diff --git a/Makefile b/Makefile index 3459c5470..7d7391d47 100644 --- a/Makefile +++ b/Makefile @@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) endif +ifdef LLAMA_SERVER_SSL + MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT + MK_LDFLAGS += -lssl -lcrypto +endif ifdef LLAMA_CODE_COVERAGE MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase '' diff --git a/README.md b/README.md index f754022de..d7dba73e6 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ### Recent API changes +- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328 - [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849 @@ -110,6 +111,7 @@ Typically finetunes of the base models below are supported as well. - [x] [InternLM2](https://huggingface.co/models?search=internlm2) - [x] [CodeShell](https://github.com/WisdomShell/codeshell) - [x] [Gemma](https://ai.google.dev/gemma) +- [x] [Mamba](https://github.com/state-spaces/mamba) **Multimodal models:** diff --git a/common/common.cpp b/common/common.cpp index c244db644..d7f650ef4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1288,6 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ctx = params.n_ctx; cparams.n_batch = params.n_batch; + cparams.n_parallel = params.n_parallel; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.seed = params.seed; diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index f6369af38..5eee32016 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1847,6 +1847,124 @@ class StarCoder2Model(Model): model_arch = gguf.MODEL_ARCH.STARCODER2 +@Model.register("MambaForCausalLM", "MambaLMHeadModel") +class MambaModel(Model): + model_arch = gguf.MODEL_ARCH.MAMBA + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 8 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf" + print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'") + neox_reader = gguf.GGUFReader(tokenizer_path, "r") + + field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL) + self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1])) + field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST) + self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) + self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES) + self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID) + self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID) + self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID) + self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0]) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + assert d_inner == 2 * d_model + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + block_count = self.hparams["n_layer"] + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + tok_embd = None + tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight" + output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight" + + for name, data_torch in self.get_tensors(): + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + if name.endswith(".A_log"): + print("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + # assuming token_embd.weight is seen before output.weight + if tok_embd is not None and new_name == output_name: + if torch.equal(tok_embd, data_torch): + print(f"{output_name} is equivalent to {tok_embd_name}, omitting") + continue + if new_name == tok_embd_name: + tok_embd = data_torch + + data = data_torch.squeeze().numpy() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert big float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 19aff18ae..dff6c68ec 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -105,6 +105,9 @@ int main(int argc, char ** argv) { ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + // ensure enough sequences are available + ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end()); + llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { @@ -174,10 +177,10 @@ int main(int argc, char ** argv) { llama_batch_clear(batch); - const int n_tokens = is_pp_shared ? pp : pl*pp; - - for (int i = 0; i < n_tokens; ++i) { - llama_batch_add(batch, 0, i, { 0 }, false); + for (int i = 0; i < pp; ++i) { + for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { + llama_batch_add(batch, 0, i, { j }, false); + } } batch.logits[batch.n_tokens - 1] = true; @@ -192,7 +195,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, pp); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 9be7eb56b..dde4d5a06 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -80,6 +80,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_len, n_parallel); + ctx_params.n_parallel = n_parallel; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -132,7 +133,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them for (int32_t i = 1; i < n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } if (n_parallel > 1) { diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7d11fcd59..a2ef0fb03 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -107,6 +107,9 @@ int main(int argc, char ** argv) { // number of simultaneous "clients" to simulate const int32_t n_clients = params.n_parallel; + // dedicate one sequence to the system prompt + params.n_parallel += 1; + // requests to simulate const int32_t n_seq = params.n_sequences; @@ -196,8 +199,8 @@ int main(int argc, char ** argv) { } // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); + for (int32_t i = 1; i <= n_clients; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("\n"); @@ -221,15 +224,17 @@ int main(int argc, char ** argv) { client.i_batch = batch.n_tokens; - llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true); + llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); client.n_decoded += 1; } if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache - for (int i = 0; i < n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); + for (int i = 1; i <= n_clients; ++i) { + llama_kv_cache_seq_rm(ctx, i, -1, -1); + // but keep the system prompt + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("%s: clearing the KV cache\n", __func__); @@ -255,7 +260,7 @@ int main(int argc, char ** argv) { tokens_prompt = ::llama_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false); + llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); } // extract the logits only for the last token @@ -366,7 +371,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9ec989389..52789ee63 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int n_batch = params.n_batch; const int max_tasks_per_batch = 32; - const int max_seq = 4*max_tasks_per_batch; + const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); @@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { const int n_batch = params.n_batch; const int max_tasks_per_batch = 128; - const int max_seq = 2*max_tasks_per_batch; + const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); @@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params const int n_batch = params.n_batch; const int max_tasks_per_batch = 32; - const int max_seq = 4*max_tasks_per_batch; + const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); @@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); + // load the model and apply lora adapter, if any std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index c21eba634..f94de1e99 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET server) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) +option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h) install(TARGETS ${TARGET} RUNTIME) @@ -7,6 +8,11 @@ target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ ) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +if (LLAMA_SERVER_SSL) + find_package(OpenSSL REQUIRED) + target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT) +endif() if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() diff --git a/examples/server/README.md b/examples/server/README.md index d0ab9709d..2602613b2 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -59,6 +59,10 @@ see https://github.com/ggerganov/llama.cpp/issues/1437 - `--log-disable`: Output logs to stdout only, default: enabled. - `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json) +**If compiled with `LLAMA_SERVER_SSL=ON`** +- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key +- `--ssl-cert-file FNAME`: path to file a PEM-encoded SSL certificate + ## Build server is build alongside everything else from the root of the project @@ -75,6 +79,28 @@ server is build alongside everything else from the root of the project cmake --build . --config Release ``` +## Build with SSL + +server can also be built with SSL support using OpenSSL 3 + +- Using `make`: + + ```bash + # NOTE: For non-system openssl, use the following: + # CXXFLAGS="-I /path/to/openssl/include" + # LDFLAGS="-L /path/to/openssl/lib" + make LLAMA_SERVER_SSL=true server + ``` + +- Using `CMake`: + + ```bash + mkdir build + cd build + cmake .. -DLLAMA_SERVER_SSL=ON + make server + ``` + ## Quick Start To get started right away, run the following command, making sure to use the correct path for the model you have: diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2dd7dfca1..5460a7e7a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -29,6 +29,7 @@ #include #include #include +#include using json = nlohmann::json; @@ -120,6 +121,11 @@ struct server_params { std::vector api_keys; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ssl_key_file = ""; + std::string ssl_cert_file = ""; +#endif + bool slots_endpoint = true; bool metrics_endpoint = false; }; @@ -661,7 +667,11 @@ struct server_context { bool load_model(const gpt_params & params_) { params = params_; + // dedicate one sequence to the system prompt + params.n_parallel += 1; + std::tie(model, ctx) = llama_init_from_gpt_params(params); + params.n_parallel -= 1; // but be sneaky about it if (model == nullptr) { LOG_ERROR("unable to load model", {{"model", params.model}}); return false; @@ -1021,8 +1031,8 @@ struct server_context { } // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); + for (int32_t i = 1; i <= params.n_parallel; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } } @@ -1127,6 +1137,7 @@ struct server_context { LOG_VERBOSE("stopped by limit", { {"id_slot", slot.id}, + {"id_task", slot.id_task}, {"n_decoded", slot.n_decoded}, {"n_predict", slot.params.n_predict}, }); @@ -1140,6 +1151,8 @@ struct server_context { } LOG_VERBOSE("next token", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, {"token", result.tok}, {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, {"has_next_token", slot.has_next_token}, @@ -1318,7 +1331,7 @@ struct server_context { const int n_embd = llama_n_embd(model); for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { continue; } @@ -1652,8 +1665,8 @@ struct server_context { {"n_cache_tokens", slot.cache_tokens.size()} }); - llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -1685,7 +1698,7 @@ struct server_context { // TODO: we always have to take into account the "system_tokens" // this is not great and needs to be improved somehow - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true); slot.n_past += 1; @@ -1765,6 +1778,15 @@ struct server_context { slot.n_past = 0; slot.n_prompt_tokens = prompt_tokens.size(); + LOG_VERBOSE("prompt tokenized", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_prompt_tokens", slot.n_prompt_tokens}, + {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, + }); + if (slot.embedding) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_batch) { @@ -1803,10 +1825,13 @@ struct server_context { slot.n_prompt_tokens = prompt_tokens.size(); LOG_VERBOSE("input truncated", { - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"n_prompt_tokens", slot.n_prompt_tokens}, + {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, }); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); @@ -1823,9 +1848,6 @@ struct server_context { // reuse any previously computed tokens that are common with the new prompt slot.n_past = common_part(slot.cache_tokens, prompt_tokens); - // remove the non-common part from the cache - slot.cache_tokens.resize(slot.n_past); - // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); @@ -1856,8 +1878,28 @@ struct server_context { } } - const int p0 = (int) system_tokens.size() + slot.n_past; - llama_kv_cache_seq_rm(ctx, slot.id, p0, -1); + // keep only the common part + int p0 = (int) system_tokens.size() + slot.n_past; + if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); + + p0 = (int) system_tokens.size(); + if (p0 != 0) { + // copy over the system prompt when there is one + llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); + } + + // there is no common part left (except for the system prompt) + slot.n_past = 0; + slot.n_past_se = 0; + slot.ga_i = 0; + // TODO: is the system prompt ever in the sampling context? + llama_sampling_reset(slot.ctx_sampling); + } + + // remove the non-common part from the cache + slot.cache_tokens.resize(slot.n_past); LOG_INFO("kv cache rm [p0, end)", { { "id_slot", slot.id }, @@ -1882,7 +1924,7 @@ struct server_context { } } - llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); + llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -1956,9 +1998,9 @@ struct server_context { LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); + llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); + llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); slot.n_past_se -= bd; @@ -2125,6 +2167,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n"); + printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n"); +#endif printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); @@ -2203,7 +2249,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, } } key_file.close(); - } else if (arg == "--timeout" || arg == "-to") { + + } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + else if (arg == "--ssl-key-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.ssl_key_file = argv[i]; + } else if (arg == "--ssl-cert-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.ssl_cert_file = argv[i]; + } +#endif + else if (arg == "--timeout" || arg == "-to") { if (++i >= argc) { invalid_param = true; break; @@ -2641,21 +2704,34 @@ int main(int argc, char ** argv) { {"system_info", llama_print_system_info()}, }); - httplib::Server svr; + std::unique_ptr svr; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") { + LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}}); + svr.reset( + new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str()) + ); + } else { + LOG_INFO("Running without SSL", {}); + svr.reset(new httplib::Server()); + } +#else + svr.reset(new httplib::Server()); +#endif std::atomic state{SERVER_STATE_LOADING_MODEL}; - svr.set_default_headers({{"Server", "llama.cpp"}}); + svr->set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { + svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); }); - svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) { + svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); switch (current_state) { case SERVER_STATE_READY: @@ -2711,7 +2787,7 @@ int main(int argc, char ** argv) { }); if (sparams.slots_endpoint) { - svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) { + svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue server_task task; task.id = ctx_server.queue_tasks.get_new_id(); @@ -2732,7 +2808,7 @@ int main(int argc, char ** argv) { } if (sparams.metrics_endpoint) { - svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) { + svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue server_task task; task.id = ctx_server.queue_tasks.get_new_id(); @@ -2829,9 +2905,9 @@ int main(int argc, char ** argv) { }); } - svr.set_logger(log_server_request); + svr->set_logger(log_server_request); - svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { const char fmt[] = "500 Internal Server Error\n%s"; char buf[BUFSIZ]; @@ -2847,7 +2923,7 @@ int main(int argc, char ** argv) { res.status = 500; }); - svr.set_error_handler([](const httplib::Request &, httplib::Response & res) { + svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { if (res.status == 401) { res.set_content("Unauthorized", "text/plain; charset=utf-8"); } @@ -2860,16 +2936,16 @@ int main(int argc, char ** argv) { }); // set timeouts and change hostname and port - svr.set_read_timeout (sparams.read_timeout); - svr.set_write_timeout(sparams.write_timeout); + svr->set_read_timeout (sparams.read_timeout); + svr->set_write_timeout(sparams.write_timeout); - if (!svr.bind_to_port(sparams.hostname, sparams.port)) { + if (!svr->bind_to_port(sparams.hostname, sparams.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } // Set the base directory for serving static files - svr.set_base_dir(sparams.public_path); + svr->set_base_dir(sparams.public_path); std::unordered_map log_data; @@ -2930,30 +3006,30 @@ int main(int argc, char ** argv) { }; // this is only called if no index.html is found in the public --path - svr.Get("/", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); return false; }); // this is only called if no index.js is found in the public --path - svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); return false; }); // this is only called if no index.html is found in the public --path - svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); return false; }); // this is only called if no index.html is found in the public --path - svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); return false; }); - svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", ctx_server.name_user.c_str() }, @@ -3053,11 +3129,11 @@ int main(int argc, char ** argv) { } }; - svr.Post("/completion", completions); // legacy - svr.Post("/completions", completions); - svr.Post("/v1/completions", completions); + svr->Post("/completion", completions); // legacy + svr->Post("/completions", completions); + svr->Post("/v1/completions", completions); - svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { + svr->Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json models = { @@ -3152,10 +3228,10 @@ int main(int argc, char ** argv) { } }; - svr.Post("/chat/completions", chat_completions); - svr.Post("/v1/chat/completions", chat_completions); + svr->Post("/chat/completions", chat_completions); + svr->Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3219,11 +3295,11 @@ int main(int argc, char ** argv) { } }); - svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { + svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { return res.set_content("", "application/json; charset=utf-8"); }); - svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); @@ -3235,7 +3311,7 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); @@ -3249,7 +3325,7 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!params.embedding) { res.status = 501; @@ -3280,7 +3356,7 @@ int main(int argc, char ** argv) { return res.set_content(result.data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!params.embedding) { res.status = 501; @@ -3351,13 +3427,13 @@ int main(int argc, char ** argv) { sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); } log_data["n_threads_http"] = std::to_string(sparams.n_threads_http); - svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; + svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; LOG_INFO("HTTP server listening", log_data); // run the HTTP server in a thread - see comment below std::thread t([&]() { - if (!svr.listen_after_bind()) { + if (!svr->listen_after_bind()) { state.store(SERVER_STATE_ERROR); return 1; } @@ -3398,7 +3474,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.start_loop(); - svr.stop(); + svr->stop(); t.join(); llama_backend_free(); diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 066698c8e..a66fed626 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -6,8 +6,8 @@ Feature: Parallel Given a server listening on localhost:8080 And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And 42 as server seed - And 512 as batch size - And 64 KV cache size + And 128 as batch size + And 256 KV cache size And 2 slots And continuous batching Then the server is starting @@ -76,6 +76,7 @@ Feature: Parallel | disabled | 128 | | enabled | 64 | + Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969 Given a prompt: """ diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index 878ac1363..aa132fa34 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -10,11 +10,10 @@ Feature: llama.cpp server # KV Cache corresponds to the total amount of tokens # that can be stored across all independent sequences: #4130 # see --ctx-size and #5568 - And 32 KV cache size - And 512 as batch size - And 1 slots - And embeddings extraction - And 32 server max tokens to predict + And 256 KV cache size + And 32 as batch size + And 2 slots + And 64 server max tokens to predict And prometheus compatible metrics exposed Then the server is starting Then the server is healthy @@ -23,18 +22,35 @@ Feature: llama.cpp server Then the server is ready And all slots are idle + Scenario Outline: Completion Given a prompt And max tokens to predict And a completion request with no api error Then tokens are predicted matching + And the completion is truncated + And prompt tokens are processed And prometheus metrics are exposed And metric llamacpp:tokens_predicted is Examples: Prompts - | prompt | n_predict | re_content | n_predicted | - | I believe the meaning of life is | 8 | (read\|going)+ | 8 | - | Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 | + | prompt | n_predict | re_content | n_prompt | n_predicted | truncated | + | I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not | + | Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids)+ | 46 | 64 | not | + + Scenario: Completion prompt truncated + Given a prompt: + """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + """ + And a completion request with no api error + Then 64 tokens are predicted matching fun|Annaks|popcorns + And the completion is truncated + And 109 prompt tokens are processed + Scenario Outline: OAI Compatibility Given a model @@ -44,11 +60,14 @@ Feature: llama.cpp server And streaming is Given an OAI compatible chat completions request with no api error Then tokens are predicted matching + And prompt tokens are processed + And the completion is truncated Examples: Prompts - | model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming | - | llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled | - | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled | + | model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated | + | llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not | + | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird)+ | -1 | 64 | enabled | | + Scenario: Tokenize / Detokenize When tokenizing: diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index d7f005836..0076f805b 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -196,12 +196,30 @@ async def step_request_completion(context, api_error): @step(u'{predicted_n:d} tokens are predicted matching {re_content}') def step_n_tokens_predicted_with_content(context, predicted_n, re_content): - assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content) + context.completion = context.tasks_result.pop() + assert_n_tokens_predicted(context.completion, predicted_n, re_content) @step(u'{predicted_n:d} tokens are predicted') def step_n_tokens_predicted(context, predicted_n): - assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n) + context.completion = context.tasks_result.pop() + assert_n_tokens_predicted(context.completion, predicted_n) + + +@step(u'the completion is truncated') +def step_assert_completion_truncated(context): + step_assert_completion_truncated(context, '') + + +@step(u'the completion is {truncated} truncated') +def step_assert_completion_truncated(context, truncated): + truncated = truncated != "not" + assert context.completion['truncated'] == truncated, f'{context.completion}' + + +@step(u'{n_prompt:d} prompt tokens are processed') +def step_impl(context, n_prompt): + assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}" @step(u'a user prompt {user_prompt}') @@ -722,7 +740,8 @@ async def oai_chat_completions(user_prompt, completion_response = { 'content': '', 'timings': { - 'predicted_n': 0 + 'predicted_n': 0, + 'prompt_n': 0 } } if async_client: @@ -763,7 +782,8 @@ async def oai_chat_completions(user_prompt, completion_response = { 'content': chat_completion_raw['choices'][0]['message'], 'timings': { - 'predicted_n': chat_completion_raw['usage']['completion_tokens'] + 'predicted_n': chat_completion_raw['usage']['completion_tokens'], + 'prompt_n': chat_completion_raw['usage']['prompt_tokens'] } } else: @@ -792,13 +812,16 @@ async def oai_chat_completions(user_prompt, if 'content' in delta: completion_response['content'] += delta['content'] completion_response['timings']['predicted_n'] += 1 + completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop' else: assert len(chat_completion.choices) == 1 completion_response = { 'content': chat_completion.choices[0].message.content, 'timings': { - 'predicted_n': chat_completion.usage.completion_tokens - } + 'predicted_n': chat_completion.usage.completion_tokens, + 'prompt_n': chat_completion.usage.prompt_tokens + }, + 'truncated': chat_completion.choices[0].finish_reason != 'stop' } if debug: print("OAI response formatted to llama.cpp:", completion_response) diff --git a/ggml.c b/ggml.c index 92b17ee6e..6eff98ab6 100644 --- a/ggml.c +++ b/ggml.c @@ -1841,6 +1841,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FLASH_ATTN", "FLASH_FF", "FLASH_ATTN_BACK", + "SSM_CONV", + "SSM_SCAN", "WIN_PART", "WIN_UNPART", "GET_REL_POS", @@ -1863,7 +1865,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1929,6 +1931,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_attn(x)", "flash_ff(x)", "flash_attn_back(x)", + "ssm_conv(x)", + "ssm_scan(x)", "win_part(x)", "win_unpart(x)", "get_rel_pos(x)", @@ -1951,7 +1955,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6154,6 +6158,108 @@ struct ggml_tensor * ggml_flash_attn_back( return result; } +// ggml_ssm_conv + +struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * c, + struct ggml_tensor * sq) { + GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(ggml_is_matrix(x)); + GGML_ASSERT(ggml_is_matrix(c)); + GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(sq->type == GGML_TYPE_I32); + + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_tokens = x->ne[1]; + const int64_t n_kv = s->ne[2]; + + GGML_ASSERT( s->ne[0] == d_conv - 1); + GGML_ASSERT( s->ne[1] == d_inner); + GGML_ASSERT( x->ne[0] == d_inner); + GGML_ASSERT(sq->ne[0] == n_kv); + GGML_ASSERT(sq->ne[1] == n_tokens); + + bool is_node = false; + + if (s->grad || x->grad || c->grad || sq->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + + result->op = GGML_OP_SSM_CONV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = s; + result->src[1] = x; + result->src[2] = c; + result->src[3] = sq; + + return result; +} + +// ggml_ssm_scan + +struct ggml_tensor * ggml_ssm_scan( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * dt, + struct ggml_tensor * A, + struct ggml_tensor * B, + struct ggml_tensor * C, + struct ggml_tensor * sq) { + GGML_ASSERT(ggml_is_contiguous(s)); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(dt)); + GGML_ASSERT(ggml_is_contiguous(A)); + GGML_ASSERT(sq->type == GGML_TYPE_I32); + GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); + GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); + GGML_ASSERT(ggml_are_same_shape(x, dt)); + + { + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_tokens = x->ne[1]; + + GGML_ASSERT(x->ne[0] == d_inner); + GGML_ASSERT(A->ne[0] == d_state); + GGML_ASSERT(A->ne[1] == d_inner); + GGML_ASSERT(B->ne[0] == d_state); + GGML_ASSERT(B->ne[1] == n_tokens); + GGML_ASSERT(C->ne[0] == d_state); + GGML_ASSERT(C->ne[1] == n_tokens); + } + + bool is_node = false; + + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + + result->op = GGML_OP_SSM_SCAN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = s; + result->src[1] = x; + result->src[2] = dt; + result->src[3] = A; + result->src[4] = B; + result->src[5] = C; + result->src[6] = sq; + + return result; +} + // ggml_win_part struct ggml_tensor * ggml_win_part( @@ -14771,6 +14877,257 @@ static void ggml_compute_forward_flash_attn_back( } } +// ggml_compute_forward_ssm_conv + +static void ggml_compute_forward_ssm_conv_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const struct ggml_tensor * src0 = dst->src[0]; // conv_state + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight + const struct ggml_tensor * src3 = dst->src[3]; // state_seq + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src2->ne[0]; // d_conv + const int nr = src0->ne[1]; // d_inner + const int n_t = src1->ne[1]; // n_tokens + const int n_kv = src0->ne[2]; // max number of sequences in the batch + + GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // for use with the destination state offset between sequences + GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + if (n_kv > 1) { + // multiple sequences means it's hard to know when it's the first time a state is read, + // so copy them all over to the destination, just to be sure. + for (int i3 = 0; i3 < n_kv; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); + // can't use memcpy because of d_conv vs d_conv - 1 + for (int i1 = 0; i1 < ir; ++i1) { + for (int i0 = 0; i0 < nc - 1; ++i0) { + // copy s0 to last (d_conv - 1) columns of s + s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; + } + } + } + } + + for (int i2 = 0; i2 < n_t; ++i2) { + int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} + float * s0; // {d_conv - 1, d_inner, n_kv} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int ne0s0; + + GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + + // avoid needing to copy the state for the first token + if (i2 == 0) { + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + ne0s0 = src0->ne[0]; + } else { + // the source is the last (d_conv - 1) columns of the destination + s0 = s + 1; + ne0s0 = nc; + } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // shift state left + for (int i0 = 0; i0 < nc - 1; ++i0) { + s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; + } + // insert x on the last column + s[(nc - 1) + i1*nc] = x0[i1]; + } + + // handle copies when there are multiple output states + for (int i3 = 1; i3 < n_kv; ++i3) { + int32_t seq = sq[i3]; + if (0 <= seq && seq < n_kv) { + float * s1 = s + (seq - sq[0])*nc*nr; + memcpy(s1, s, nc*ir*sizeof(float)); + } else { + // stop at negative or too big seq_ids + break; + } + } + + // it seems a little faster when this is separate from the state shift + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + sumf += s[i] * c[i]; + } + x[i1] = sumf; + } + } +} + +static void ggml_compute_forward_ssm_conv( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_f32(params, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_ssm_scan + +static void ggml_compute_forward_ssm_scan_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const struct ggml_tensor * src0 = dst->src[0]; // s + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // dt + const struct ggml_tensor * src3 = dst->src[3]; // A + const struct ggml_tensor * src4 = dst->src[4]; // B + const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src6 = dst->src[6]; // sq + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens in the batch + const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(float)); + GGML_ASSERT(src4->nb[0] == sizeof(float)); + GGML_ASSERT(src5->nb[0] == sizeof(float)); + // required for the dot product between s and C, and when copying the states + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[2]) + GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + if (n_kv > 1) { + // it's hard to know if the source states have already been copied + // when there are multiple, so copy them already. + for (int i3 = 0; i3 < n_kv; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); + memcpy(s, s0, nc*ir*sizeof(float)); + } + } + + for (int i2 = 0; i2 < n_t; ++i2) { + int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + + GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + + // avoid needing to copy the state for the first token + if (i2 == 0) { + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + } else { + // otherwise the source is the same as the destination + s0 = s; + } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + + // handle copies when there are multiple output states + for (int i3 = 1; i3 < n_kv; ++i3) { + int32_t seq = sq[i3]; + if (0 <= seq && seq < n_kv) { + float * s1 = s + (seq - sq[0])*nc*nr; + memcpy(s1, s, nc*ir*sizeof(float)); + } else { + // stop at negative or too big seq_ids + break; + } + } + } +} + +static void ggml_compute_forward_ssm_scan( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_scan_f32(params, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_win_part static void ggml_compute_forward_win_part_f32( @@ -15830,6 +16187,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm bool masked = t != 0; ggml_compute_forward_flash_attn_back(params, masked, tensor); } break; + case GGML_OP_SSM_CONV: + { + ggml_compute_forward_ssm_conv(params, tensor); + } break; + case GGML_OP_SSM_SCAN: + { + ggml_compute_forward_ssm_scan(params, tensor); + } break; case GGML_OP_WIN_PART: { ggml_compute_forward_win_part(params, tensor); @@ -16884,6 +17249,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // not supported } break; + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_UNARY: @@ -17590,6 +17960,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: diff --git a/ggml.h b/ggml.h index 0ea4f8847..a13b0cec4 100644 --- a/ggml.h +++ b/ggml.h @@ -472,6 +472,8 @@ extern "C" { GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, + GGML_OP_SSM_CONV, + GGML_OP_SSM_SCAN, GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, @@ -1728,6 +1730,23 @@ extern "C" { struct ggml_tensor * c0, struct ggml_tensor * c1); + GGML_API struct ggml_tensor * ggml_ssm_conv( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * c, + struct ggml_tensor * sq); + + GGML_API struct ggml_tensor * ggml_ssm_scan( + struct ggml_context * ctx, + struct ggml_tensor * s, + struct ggml_tensor * x, + struct ggml_tensor * dt, + struct ggml_tensor * A, + struct ggml_tensor * B, + struct ggml_tensor * C, + struct ggml_tensor * sq); + // partition into non-overlapping windows with padding if needed // example: // a: 768 64 64 1 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a62139811..b23badb10 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -61,6 +61,12 @@ class Keys: SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + class SSM: + CONV_KERNEL = "{arch}.ssm.conv_kernel" + INNER_SIZE = "{arch}.ssm.inner_size" + STATE_SIZE = "{arch}.ssm.state_size" + TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + class Tokenizer: MODEL = "tokenizer.ggml.model" LIST = "tokenizer.ggml.tokens" @@ -113,6 +119,7 @@ class MODEL_ARCH(IntEnum): MINICPM = auto() GEMMA = auto() STARCODER2 = auto() + MAMBA = auto() class MODEL_TENSOR(IntEnum): @@ -144,6 +151,13 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() + SSM_IN = auto() + SSM_CONV1D = auto() + SSM_X = auto() + SSM_DT = auto() + SSM_A = auto() + SSM_D = auto() + SSM_OUT = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -171,6 +185,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MINICPM: "minicpm", MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.MAMBA: "mamba", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -202,6 +217,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", + MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", + MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", + MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -543,6 +565,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.MAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + ], # TODO } @@ -734,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED +# SSM +KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL +KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE +KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE +KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK + # tokenization KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 801160832..e49c5db68 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -382,6 +382,18 @@ class GGUFWriter: def add_rope_scaling_finetuned(self, value: bool) -> None: self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value) + def add_ssm_conv_kernel(self, value: int) -> None: + self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value) + + def add_ssm_inner_size(self, value: int) -> None: + self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value) + + def add_ssm_state_size(self, value: int) -> None: + self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value) + + def add_ssm_time_step_rank(self, value: int) -> None: + self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index db2ec9704..ed89955d8 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -20,6 +20,9 @@ class TensorNameMap: "wte", # gpt2 "transformer.embd.wte", # phi2 "model.tok_embeddings", # internlm2 + "model.embedding", # mamba-qbert + "backbone.embedding", # mamba + "backbone.embeddings", # mamba-hf ), # Token type embeddings @@ -44,7 +47,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -61,6 +64,8 @@ class TensorNameMap: "language_model.encoder.final_layernorm", # persimmon "model.final_layernorm", # persimmon "lm_head.ln", # phi2 + "model.norm_f", # mamba-qbert + "backbone.norm_f", # mamba ), # Rope frequencies @@ -86,6 +91,8 @@ class TensorNameMap: "transformer.h.{bid}.ln", # phi2 "model.layers.layers.{bid}.norm", # plamo "model.layers.{bid}.attention_norm", # internlm2 + "model.layers.{bid}.norm", # mamba-qbert + "backbone.layers.{bid}.norm", # mamba ), # Attention norm 2 @@ -282,7 +289,42 @@ class TensorNameMap: MODEL_TENSOR.LAYER_OUT_NORM: ( "encoder.layer.{bid}.output.LayerNorm", # bert "encoder.layers.{bid}.norm2", # nomic-bert - ) + ), + + MODEL_TENSOR.SSM_IN: ( + "model.layers.{bid}.in_proj", + "backbone.layers.{bid}.mixer.in_proj", + ), + + MODEL_TENSOR.SSM_CONV1D: ( + "model.layers.{bid}.conv1d", + "backbone.layers.{bid}.mixer.conv1d", + ), + + MODEL_TENSOR.SSM_X: ( + "model.layers.{bid}.x_proj", + "backbone.layers.{bid}.mixer.x_proj", + ), + + MODEL_TENSOR.SSM_DT: ( + "model.layers.{bid}.dt_proj", + "backbone.layers.{bid}.mixer.dt_proj", + ), + + MODEL_TENSOR.SSM_A: ( + "model.layers.{bid}.A_log", + "backbone.layers.{bid}.mixer.A_log", + ), + + MODEL_TENSOR.SSM_D: ( + "model.layers.{bid}.D", + "backbone.layers.{bid}.mixer.D", + ), + + MODEL_TENSOR.SSM_OUT: ( + "model.layers.{bid}.out_proj", + "backbone.layers.{bid}.mixer.out_proj", + ), } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/llama.cpp b/llama.cpp index 4a20b7928..8c147a42b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -213,6 +213,7 @@ enum llm_arch { LLM_ARCH_MINICPM, LLM_ARCH_GEMMA, LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, LLM_ARCH_UNKNOWN, }; @@ -241,6 +242,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM, "minicpm" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -284,6 +286,11 @@ enum llm_kv { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_SSM_INNER_SIZE, + LLM_KV_SSM_CONV_KERNEL, + LLM_KV_SSM_STATE_SIZE, + LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, @@ -342,6 +349,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, @@ -399,6 +411,13 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -801,6 +820,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_MAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1613,6 +1648,12 @@ struct llama_hparams { float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; + // for State Space Models + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -1641,6 +1682,11 @@ struct llama_hparams { if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; + if (this->ssm_d_conv != other.ssm_d_conv) return true; + if (this->ssm_d_inner != other.ssm_d_inner) return true; + if (this->ssm_d_state != other.ssm_d_state) return true; + if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + const float EPSILON = 1e-9f; if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; @@ -1652,6 +1698,9 @@ struct llama_hparams { } uint32_t n_gqa() const { + if (n_head_kv == 0) { + return 0; + } return n_head/n_head_kv; } @@ -1662,6 +1711,18 @@ struct llama_hparams { uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads return n_embd_head_v * n_head_kv; } + + uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + // corresponds to Mamba's conv_states size + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + } + + uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; + } }; struct llama_cparams { @@ -1739,11 +1800,27 @@ struct llama_layer { struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 struct ggml_tensor * ffn_act; + + // mamba proj + struct ggml_tensor * ssm_in; + struct ggml_tensor * ssm_x; + struct ggml_tensor * ssm_dt; + struct ggml_tensor * ssm_out; + + // mamba + struct ggml_tensor * ssm_conv1d; + struct ggml_tensor * ssm_a; + struct ggml_tensor * ssm_d; + + // mamba bias + struct ggml_tensor * ssm_conv1d_b; + struct ggml_tensor * ssm_dt_b; }; struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; + int32_t src = 0; // used by recurrent state models to copy states std::set seq_id; @@ -1764,6 +1841,9 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; + bool do_copy = false; + // with recurrent state models, a cell can hold the state for more than one past token + bool recurrent = false; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -2003,11 +2083,14 @@ struct llama_context { struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch] - struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] - struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx] - struct ggml_tensor * inp_K_shift; // I32 [n_ctx] + struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_KQ_pos; // F32 [kv_size] + struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [kv_size] + struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -2023,25 +2106,42 @@ static bool llama_kv_cache_init( const llama_model & model, ggml_type type_k, ggml_type type_v, - uint32_t n_ctx, + uint32_t kv_size, bool offload) { const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); const int64_t n_layer = hparams.n_layer; cache.has_shift = false; + // TODO: find a nicer way to add other recurrent model architectures + cache.recurrent = model.arch == LLM_ARCH_MAMBA; + + // TODO: support mixed reccurent Transformer architectues + // NOTE: (!a || b) is a logical implication (a -> b) + GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); + GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); + GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); + GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); + cache.head = 0; - cache.size = n_ctx; + cache.size = kv_size; cache.used = 0; cache.type_k = type_k; cache.type_v = type_v; cache.cells.clear(); - cache.cells.resize(n_ctx); + cache.cells.resize(kv_size); + + if (cache.recurrent) { + // init state copy sources + for (uint32_t i = 0; i < cache.size; ++i) { + cache.cells[i].src = i; + } + } #ifdef GGML_USE_CLBLAST offload = false; @@ -2080,8 +2180,8 @@ static bool llama_kv_cache_init( for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); @@ -2115,6 +2215,54 @@ static bool llama_kv_cache_find_slot( const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; + if (cache.recurrent) { + // For recurrent state architectures (like Mamba), + // each KV cache cell can store the state for a whole sequence. + + llama_seq_id min = cache.size - 1; + llama_seq_id max = 0; + + for (uint32_t i = 0; i < n_tokens; ++i) { + for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + // make sure it's a valid seq_id + if ((uint32_t) seq_id < cache.size) { + if (seq_id > max) { + max = seq_id; + } + if (seq_id < min) { + min = seq_id; + } + // Assuming the tokens are in-order + if (batch.pos[i] != cache.cells[seq_id].pos + 1) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); + } + if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { + cache.used += 1; + } + cache.cells[seq_id].pos = batch.pos[i]; + // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set + } else { + // too big seq_id + // TODO: would it be possible to resize the KV cache size instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + return false; + } + } + } + + // allow getting the range of used cells, from head to head + n + cache.head = min; + cache.n = max - min + 1; + + // sanity check + return max >= min; + } + // otherwise, one cell per token. + if (n_tokens > n_ctx) { LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); return false; @@ -2184,7 +2332,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { cache.used = 0; } -static void llama_kv_cache_seq_rm( +static bool llama_kv_cache_seq_rm( struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -2194,6 +2342,25 @@ static void llama_kv_cache_seq_rm( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + // models like Mamba can't have a state partially erased + if (cache.recurrent) { + if (seq_id >= (int64_t) cache.size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + // partial intersection is invalid + if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { + return false; + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + } + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { if (seq_id < 0) { @@ -2215,6 +2382,8 @@ static void llama_kv_cache_seq_rm( // If we freed up a slot, set head to it so searching can start there. if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + + return true; } static void llama_kv_cache_seq_cp( @@ -2226,6 +2395,29 @@ static void llama_kv_cache_seq_cp( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.recurrent) { + if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { + seq_id_src = cache.cells[seq_id_src].src; + GGML_ASSERT((uint32_t) seq_id_src < cache.size); + // intent to "copy from" + // supports copy chains thanks to taking the source of the source + cache.cells[seq_id_dst].src = seq_id_src; + + // preserve the "keep or clear" status of the copied sequence + if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { + cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); + } else { + cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + } + + cache.do_copy = true; + + cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + } + return; + } + // otherwise, this is the KV cache of a Transformer-like model + cache.head = 0; for (uint32_t i = 0; i < cache.size; ++i) { @@ -2265,6 +2457,17 @@ static void llama_kv_cache_seq_add( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.recurrent) { + // for Mamba-like models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) cache.size) { + llama_kv_cell & cell = cache.cells[seq_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } + } + return; + } + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.has_shift = true; @@ -2298,6 +2501,17 @@ static void llama_kv_cache_seq_div( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.recurrent) { + // for Mamba-like models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) cache.size) { + llama_kv_cell & cell = cache.cells[seq_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + return; + } + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.has_shift = true; @@ -3117,7 +3331,7 @@ static void llm_load_hparams( // sanity check for n_rot (optional) { - hparams.n_rot = hparams.n_embd / hparams.n_head; + hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); @@ -3130,10 +3344,10 @@ static void llm_load_hparams( // gpt-j n_rot = rotary_dim } - hparams.n_embd_head_k = hparams.n_embd / hparams.n_head; + hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - hparams.n_embd_head_v = hparams.n_embd / hparams.n_head; + hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); // arch-specific KVs @@ -3383,6 +3597,36 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: model.type = e_model::MODEL_MEDIUM; break; + case 1536: model.type = e_model::MODEL_LARGE; break; + case 2048: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -3702,6 +3946,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); if (ml.n_elements >= 1e12) { @@ -4609,6 +4857,57 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}); } } break; + case LLM_ARCH_MAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + + layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + + layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + + layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -4834,6 +5133,8 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(kv.size == n_ctx); + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed @@ -5043,6 +5344,8 @@ static struct ggml_tensor * llm_build_kqv( cb(kq, "kq_soft_max_ext", il); } + GGML_ASSERT(kv.size == n_ctx); + // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -5190,8 +5493,8 @@ struct llm_build_context { norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (batch.n_tokens), - n_kv (worst_case ? n_ctx : kv_self.n), - kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), + n_kv (worst_case ? kv_self.size : kv_self.n), + kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -5220,6 +5523,8 @@ struct llm_build_context { struct ggml_cgraph * build_k_shift() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + GGML_ASSERT(kv_self.size == n_ctx); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions @@ -5238,6 +5543,27 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_s_copy() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + GGML_ASSERT(kv_self.recurrent); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + + conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy); + ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy); + + // TODO: name the intermediate tensors with cb() + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); + } + + return gf; + } + struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -7835,6 +8161,145 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_mamba() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t d_model = n_embd; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + GGML_ASSERT(2 * d_model == d_inner); + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); + cb(inpL, "inp_embd", -1); + + struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0); + struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0); + + for (int il = 0; il < n_layer; ++il) { + // (ab)using the KV cache to store the states + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + + // clear states of sequences which are starting at the beginning of this batch + { + conv_states = ggml_mul(ctx0, + ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), + state_mask); + ssm_states = ggml_mul(ctx0, + ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), + state_mask); + } + + conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); + ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_tokens} + struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); + struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + + // conv + { + // Custom operator which is needed only to ease simultaneous sequence processing. + // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weigth, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // The new conv_states is the last (d_conv - 1) columns + // of the last 3rd dimensional "layer" of the self-overlapping view. + // For simultaneous sequences, it's more complicated. + struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + + // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + + // extract x from x_conv + x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + + // bias + x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx0, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); + // split + struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); + struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + + // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, + // because only a single tensor can be returned. + struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + + // store last states (the second part of y_ssm_states) + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states)))); + + struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + + // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); + + // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -7871,6 +8336,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { return result; } +static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { + llama_batch dummy; + dummy.n_tokens = 0; + + llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + + struct llm_build_context llm(lctx, dummy, cb, false); + + llm.init(); + + struct ggml_cgraph * result = llm.build_s_copy(); + + llm.free(); + + return result; +} + static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch, @@ -7985,6 +8467,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_starcoder2(); } break; + case LLM_ARCH_MAMBA: + { + result = llm.build_mamba(); + } break; default: GGML_ASSERT(false); } @@ -7995,19 +8481,29 @@ static struct ggml_cgraph * llama_build_graph( } static void llama_set_k_shift(llama_context & lctx) { - const auto & cparams = lctx.cparams; - - const int64_t n_ctx = cparams.n_ctx; + const int64_t kv_size = lctx.kv_self.size; assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; - for (int i = 0; i < n_ctx; ++i) { + for (int i = 0; i < kv_size; ++i) { data[i] = lctx.kv_self.cells[i].delta; } } +static void llama_set_s_copy(llama_context & lctx) { + const int64_t kv_size = lctx.kv_self.size; + + assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + for (int i = 0; i < kv_size; ++i) { + data[i] = lctx.kv_self.cells[i].src; + } +} + static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // // set input data @@ -8044,6 +8540,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the batch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; @@ -8149,6 +8648,53 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } + + if (kv_self.recurrent) { + const int64_t n_kv = kv_self.n; + + { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); + float * data = (float *) lctx.inp_s_mask->data; + + // states which are not affected by the current batch are left untouched + for (int i = 0; i < n_kv; ++i) { + llama_seq_id seq_id = i + lctx.kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; + bool has_self_seq = kv_cell.has_seq_id(seq_id); + + data[i] = (float) has_self_seq; + + // ensure current sequences will be kept + if (!has_self_seq && kv_cell.pos >= 0) { + kv_cell.seq_id.insert(seq_id); + } + } + } + // For Mamba (and other recurrent architectures), + // update the correct state(s)/sequence(s) for each token of the batch. + // Like with the KQ_mask, if a token in the batch has multiple sequences, + // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). + { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_seq->data; + + for (int j = 0; j < n_tokens; ++j) { + const int32_t n_seq = batch.n_seq_id[j]; + GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + + for (int i = 0; i < n_kv; ++i) { + if (i < n_seq) { + // for this type of model, the head is the minimum seq_id of the batch + data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; + } else { + data[j*n_kv + i] = -1; + } + } + } + } + } } static void llama_graph_compute( @@ -8271,11 +8817,13 @@ static int llama_decode_internal( return 1; } - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); + if (!kv_self.recurrent) { + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); + } } //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); @@ -8701,6 +9249,26 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } + if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { + llama_set_s_copy(lctx); + + { + ggml_cgraph * gf = llama_build_graph_s_copy(lctx); + + llama_graph_compute(lctx, gf, lctx.cparams.n_threads); + } + + { + auto & kv_self = lctx.kv_self; + + kv_self.do_copy = false; + + for (uint32_t i = 0; i < kv_self.size; ++i) { + kv_self.cells[i].src = i; + } + } + } + // defragment the KV cache if needed if (lctx.kv_self.do_defrag) { llama_kv_cache_defrag_internal(lctx); @@ -11535,6 +12103,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); + // do not quantize Mamba's small yet 2D weights + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("ssm_x.weight") == std::string::npos; + quantize &= name.find("ssm_dt.weight") == std::string::npos; + enum ggml_type new_type; void * new_data; size_t new_size; @@ -11985,6 +12559,7 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, + /*.n_parallel =*/ 1, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, @@ -12146,6 +12721,7 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; cparams.n_batch = params.n_batch; + // TODO: maybe add n_parallel here too cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -12203,8 +12779,18 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - const ggml_type type_k = params.type_k; - const ggml_type type_v = params.type_v; + uint32_t kv_size = cparams.n_ctx; + ggml_type type_k = params.type_k; + ggml_type type_v = params.type_v; + + // Mamba only needs a constant number of KV cache cells per sequence + if (model->arch == LLM_ARCH_MAMBA) { + // Mamba needs at least as many KV cells as there are sequences kept at any time + kv_size = std::max((uint32_t) 1, params.n_parallel); + // it's probably best to keep as much precision as possible for the states + type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states + type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states + } GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); @@ -12304,7 +12890,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -12338,7 +12924,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*8, + /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)), /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -12347,11 +12933,16 @@ struct llama_context * llama_new_context_with_model( ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); - ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); - ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx); - ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); + ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch); + ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size); + ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size); ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); + if (ctx->kv_self.recurrent) { + ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size); + ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size); + ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch); + } ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); @@ -12361,6 +12952,11 @@ struct llama_context * llama_new_context_with_model( ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_mean, "inp_mean"); ggml_set_name(ctx->inp_cls, "inp_cls"); + if (ctx->kv_self.recurrent) { + ggml_set_name(ctx->inp_s_copy, "inp_s_copy"); + ggml_set_name(ctx->inp_s_mask, "inp_s_mask"); + ggml_set_name(ctx->inp_s_seq, "inp_s_seq"); + } ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, @@ -12447,6 +13043,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) { return ctx->cparams.n_batch; } +uint32_t llama_n_max_seq(const struct llama_context * ctx) { + return ctx->kv_self.size; +} + enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { return model->vocab.type; } @@ -12460,6 +13060,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_MPT: case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: + case LLM_ARCH_MAMBA: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -12713,8 +13314,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) { llama_kv_cache_clear(ctx->kv_self); } -void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); +bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { @@ -12891,8 +13492,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); const size_t kv_buf_size = kv_self.total_size(); const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); @@ -12913,6 +13514,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); + if (kv_self.recurrent) { + // v is contiguous for recurrent models + // TODO: use other tensors for state models than k and v + const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); + + tmp_buf.resize(v_size); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size()); + data_ctx->write(tmp_buf.data(), tmp_buf.size()); + continue; + } + // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); @@ -13005,8 +13617,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); size_t kv_buf_size; uint32_t kv_head; @@ -13027,6 +13639,16 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; + if (kv_self.recurrent) { + // v is contiguous for recurrent models + // TODO: use other tensors for state models than k and v + const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); + + ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size); + inp += v_size; + continue; + } + // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); diff --git a/llama.h b/llama.h index 3dc162b07..7a107c7f3 100644 --- a/llama.h +++ b/llama.h @@ -235,6 +235,7 @@ extern "C" { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models) uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing @@ -376,6 +377,7 @@ extern "C" { LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_max_seq (const struct llama_context * ctx); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); @@ -502,7 +504,7 @@ extern "C" { // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_rm( + LLAMA_API bool llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0,