Merge branch 'master' into xsn/cleanup_oai
commit 51f80eb7ee

21 changed files with 1542 additions and 135 deletions

Makefile (4 changes)
@@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE
MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif

ifdef LLAMA_SERVER_SSL
MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT
MK_LDFLAGS += -lssl -lcrypto
endif

ifdef LLAMA_CODE_COVERAGE
MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase ''

@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849

@@ -110,6 +111,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)

**Multimodal models:**

@@ -1288,6 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_parallel = params.n_parallel;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;

@@ -1847,6 +1847,124 @@ class StarCoder2Model(Model):
    model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
class MambaModel(Model):
    model_arch = gguf.MODEL_ARCH.MAMBA

    def set_vocab(self):
        vocab_size = self.hparams["vocab_size"]
        # Round vocab size to next multiple of 8
        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8)
        # pad using ceiling division
        # ref: https://stackoverflow.com/a/17511341/22827863
        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
        self.hparams["vocab_size"] = vocab_size

        if (self.dir_model / "tokenizer.json").is_file():
            self._set_vocab_gpt2()
        else:
            # Use the GPT-NeoX tokenizer when no tokenizer files are present
            tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
            print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
            neox_reader = gguf.GGUFReader(tokenizer_path, "r")

            field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
            self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
            field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
            self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
            field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
            self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
            field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
            self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
            field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
            self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
            field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
            self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
            field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
            self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])

    def set_gguf_parameters(self):
        d_model = self.find_hparam(["hidden_size", "d_model"])
        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
        # ceiling division
        # ref: https://stackoverflow.com/a/17511341/22827863
        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

        # Fail early for models which don't have a block expansion factor of 2
        assert d_inner == 2 * d_model

        self.gguf_writer.add_name(self.dir_model.name)
        self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
        self.gguf_writer.add_embedding_length(d_model)
        self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_ssm_conv_kernel(d_conv)
        self.gguf_writer.add_ssm_inner_size(d_inner)
        self.gguf_writer.add_ssm_state_size(d_state)
        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams["n_layer"]
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        tok_embd = None
        tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight"
        output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight"

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            if name.endswith(".A_log"):
                print("A_log --> A ==> " + new_name)
                data_torch = -torch.exp(data_torch)

            # assuming token_embd.weight is seen before output.weight
            if tok_embd is not None and new_name == output_name:
                if torch.equal(tok_embd, data_torch):
                    print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
                    continue
            if new_name == tok_embd_name:
                tok_embd = data_torch

            data = data_torch.squeeze().numpy()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert big float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######

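The `set_vocab` padding above uses Python's negated floor division as a ceiling division. A minimal sketch of what that one-liner computes (standalone example with illustrative numbers, not taken from a real model config):

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 8) -> int:
    # ceil(vocab_size / pad_to) == -(vocab_size // -pad_to), then scale back up to a multiple of pad_to
    return -(vocab_size // -pad_to) * pad_to

print(pad_vocab_size(50277))  # 50280, the next multiple of 8
print(pad_vocab_size(50280))  # 50280, already a multiple so it is unchanged
```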
@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

// ensure enough sequences are available
ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

if (ctx == NULL) {

@@ -174,10 +177,10 @@ int main(int argc, char ** argv) {

llama_batch_clear(batch);

const int n_tokens = is_pp_shared ? pp : pl*pp;

for (int i = 0; i < n_tokens; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
llama_batch_add(batch, 0, i, { j }, false);
}
}
batch.logits[batch.n_tokens - 1] = true;

@@ -192,7 +195,7 @@ int main(int argc, char ** argv) {

if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}

@@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_parallel = n_parallel;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (int32_t i = 1; i < n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

if (n_parallel > 1) {

@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;

// dedicate one sequence to the system prompt
params.n_parallel += 1;

// requests to simulate
const int32_t n_seq = params.n_sequences;

@@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("\n");

@@ -221,15 +224,17 @@ int main(int argc, char ** argv) {

client.i_batch = batch.n_tokens;

llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);

client.n_decoded += 1;
}

if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("%s: clearing the KV cache\n", __func__);

@@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);

for (size_t i = 0; i < tokens_prompt.size(); ++i) {
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
}

// extract the logits only for the last token

@@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
}

// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);

const auto t_main_end = ggml_time_us();

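A recurring change in the hunks above is `client.id` becoming `client.id + 1` wherever a sequence id is needed: sequence 0 is now dedicated to the shared system prompt, and each simulated client uses the next id. A small illustrative sketch of that layout (hypothetical helper code, not part of this commit):

```python
# Sequence-id layout once sequence 0 is reserved for the system prompt.
n_clients = 4
system_seq = 0                                        # holds the shared system prompt
client_seqs = [cid + 1 for cid in range(n_clients)]   # client.id -> seq_id = client.id + 1

# Conceptually, the system prompt's KV cache is then copied into every client
# sequence, mirroring llama_kv_cache_seq_cp(ctx, 0, i, -1, -1) in the diff above.
for seq in client_seqs:
    print(f"copy KV cache of seq {system_seq} into seq {seq}")
```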
@@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

@@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

@@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

@@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;

// ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = std::max(4, params.n_parallel);

// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {

@@ -1,5 +1,6 @@
set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)

@@ -7,6 +8,11 @@ target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL)
find_package(OpenSSL REQUIRED)
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT)
endif()
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()

@@ -59,6 +59,10 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
- `--log-disable`: Output logs to stdout only, default: enabled.
- `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json)

**If compiled with `LLAMA_SERVER_SSL=ON`**
- `--ssl-key-file FNAME`: path to a file containing a PEM-encoded SSL private key
- `--ssl-cert-file FNAME`: path to a file containing a PEM-encoded SSL certificate

## Build

server is built alongside everything else from the root of the project

@@ -75,6 +79,28 @@ server is built alongside everything else from the root of the project
cmake --build . --config Release
```

## Build with SSL

server can also be built with SSL support using OpenSSL 3

- Using `make`:

  ```bash
  # NOTE: For non-system openssl, use the following:
  #   CXXFLAGS="-I /path/to/openssl/include"
  #   LDFLAGS="-L /path/to/openssl/lib"
  make LLAMA_SERVER_SSL=true server
  ```

- Using `CMake`:

  ```bash
  mkdir build
  cd build
  cmake .. -DLLAMA_SERVER_SSL=ON
  make server
  ```

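Once an SSL-enabled build is started with `--ssl-key-file` and `--ssl-cert-file`, the usual endpoints are served over HTTPS. A minimal sketch of probing such a server from Python (host, port, and certificate path are placeholders, not values from this commit):

```python
import requests

# assumes the server listens on https://localhost:8443 and was started with a
# self-signed certificate written to cert.pem (both placeholders)
base_url = "https://localhost:8443"

health = requests.get(f"{base_url}/health", verify="cert.pem")
print(health.status_code, health.text)
```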
## Quick Start

To get started right away, run the following command, making sure to use the correct path for the model you have:

@@ -27,6 +27,7 @@
#include <mutex>
#include <thread>
#include <signal.h>
#include <memory>

using json = nlohmann::json;

@@ -118,6 +119,11 @@ struct server_params {

std::vector<std::string> api_keys;

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
std::string ssl_key_file = "";
std::string ssl_cert_file = "";
#endif

bool slots_endpoint = true;
bool metrics_endpoint = false;
};

@@ -659,7 +665,11 @@ struct server_context {
bool load_model(const gpt_params & params_) {
params = params_;

// dedicate one sequence to the system prompt
params.n_parallel += 1;

std::tie(model, ctx) = llama_init_from_gpt_params(params);
params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr) {
LOG_ERROR("unable to load model", {{"model", params.model}});
return false;

@@ -1018,8 +1028,8 @@ struct server_context {
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
for (int32_t i = 1; i <= params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}

@@ -1124,6 +1134,7 @@ struct server_context {

LOG_VERBOSE("stopped by limit", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_decoded", slot.n_decoded},
{"n_predict", slot.params.n_predict},
});

@@ -1137,6 +1148,8 @@ struct server_context {
}

LOG_VERBOSE("next token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
{"has_next_token", slot.has_next_token},

@@ -1306,7 +1319,7 @@ struct server_context {
const int n_embd = llama_n_embd(model);

for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
continue;
}

@@ -1633,8 +1646,8 @@ struct server_context {
{"n_cache_tokens", slot.cache_tokens.size()}
});

llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);

if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {

@@ -1666,7 +1679,7 @@ struct server_context {

// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);

slot.n_past += 1;

@@ -1746,6 +1759,15 @@ struct server_context {
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();

LOG_VERBOSE("prompt tokenized", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});

if (slot.embedding) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_batch) {

@@ -1784,10 +1806,13 @@ struct server_context {
slot.n_prompt_tokens = prompt_tokens.size();

LOG_VERBOSE("input truncated", {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});

GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);

@@ -1804,9 +1829,6 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);

// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);

@@ -1837,8 +1859,28 @@ struct server_context {
}
}

const int p0 = (int) system_tokens.size() + slot.n_past;
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

p0 = (int) system_tokens.size();
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
}

// there is no common part left (except for the system prompt)
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(slot.ctx_sampling);
}

// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);

LOG_INFO("kv cache rm [p0, end)", {
{ "id_slot", slot.id },

@@ -1863,7 +1905,7 @@ struct server_context {
}
}

llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);

@@ -1937,9 +1979,9 @@ struct server_context {
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);

slot.n_past_se -= bd;

@@ -2106,6 +2148,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n");
printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n");
#endif
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);

@@ -2184,7 +2230,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
}
}
key_file.close();
} else if (arg == "--timeout" || arg == "-to") {
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
else if (arg == "--ssl-key-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.ssl_key_file = argv[i];
} else if (arg == "--ssl-cert-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.ssl_cert_file = argv[i];
}
#endif
else if (arg == "--timeout" || arg == "-to") {
if (++i >= argc) {
invalid_param = true;
break;

@@ -2622,23 +2685,36 @@ int main(int argc, char ** argv) {
{"system_info", llama_print_system_info()},
});

httplib::Server svr;
std::unique_ptr<httplib::Server> svr;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") {
LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}});
svr.reset(
new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str())
);
} else {
LOG_INFO("Running without SSL", {});
svr.reset(new httplib::Server());
}
#else
svr.reset(new httplib::Server());
#endif

std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

svr.set_default_headers({{"Server", "llama.cpp"}});
svr->set_default_headers({{"Server", "llama.cpp"}});

// CORS preflight
svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
});

svr.set_logger(log_server_request);
svr->set_logger(log_server_request);

svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";

char buf[BUFSIZ];

@@ -2654,7 +2730,7 @@ int main(int argc, char ** argv) {
res.status = 500;
});

svr.set_error_handler([](const httplib::Request &, httplib::Response & res) {
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 401) {
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}

@@ -2667,10 +2743,10 @@ int main(int argc, char ** argv) {
});

// set timeouts and change hostname and port
svr.set_read_timeout (sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
svr->set_read_timeout (sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout);

if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}

@@ -2761,7 +2837,7 @@ int main(int argc, char ** argv) {
};

// register server middlewares
svr.set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}

@@ -3296,7 +3372,7 @@ int main(int argc, char ** argv) {
// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
svr->set_base_dir(sparams.public_path);
}

// using embedded static files

@@ -3307,33 +3383,33 @@ int main(int argc, char ** argv) {
};
};

svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
// TODO @ngxson : I have no idea what it is... maybe this is redundant?
return res.set_content("", "application/json; charset=utf-8");
});
svr.Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr.Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr.Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr.Get("/json-schema-to-grammar.mjs", handle_static_file(
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));

// register API routes
svr.Get ("/health", handle_health);
svr.Get ("/slots", handle_slots);
svr.Get ("/metrics", handle_metrics);
svr.Get ("/props", handle_props);
svr.Get ("/v1/models", handle_models);
svr.Post("/completion", handle_completions); // legacy
svr.Post("/completions", handle_completions);
svr.Post("/v1/completions", handle_completions);
svr.Post("/chat/completions", handle_chat_completions);
svr.Post("/v1/chat/completions", handle_chat_completions);
svr.Post("/infill", handle_infill);
svr.Post("/embedding", handle_embeddings); // legacy
svr.Post("/embeddings", handle_embeddings);
svr.Post("/v1/embeddings", handle_embeddings);
svr.Post("/tokenize", handle_tokenize);
svr.Post("/detokenize", handle_detokenize);
svr->Get ("/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
svr->Get ("/v1/models", handle_models);
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);

//
// Start the server

@@ -3343,13 +3419,13 @@ int main(int argc, char ** argv) {
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
}
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };

LOG_INFO("HTTP server listening", log_data);

// run the HTTP server in a thread - see comment below
std::thread t([&]() {
if (!svr.listen_after_bind()) {
if (!svr->listen_after_bind()) {
state.store(SERVER_STATE_ERROR);
return 1;
}

@@ -3390,7 +3466,7 @@ int main(int argc, char ** argv) {

ctx_server.queue_tasks.start_loop();

svr.stop();
svr->stop();
t.join();

llama_backend_free();

@@ -6,8 +6,8 @@ Feature: Parallel
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And 42 as server seed
And 512 as batch size
And 64 KV cache size
And 128 as batch size
And 256 KV cache size
And 2 slots
And continuous batching
Then the server is starting

@@ -76,6 +76,7 @@ Feature: Parallel
| disabled | 128 |
| enabled | 64 |


Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
Given a prompt:
"""

@@ -10,11 +10,10 @@ Feature: llama.cpp server
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
# see --ctx-size and #5568
And 32 KV cache size
And 512 as batch size
And 1 slots
And embeddings extraction
And 32 server max tokens to predict
And 256 KV cache size
And 32 as batch size
And 2 slots
And 64 server max tokens to predict
And prometheus compatible metrics exposed
Then the server is starting
Then the server is healthy

@@ -23,18 +22,35 @@ Feature: llama.cpp server
Then the server is ready
And all slots are idle


Scenario Outline: Completion
Given a prompt <prompt>
And <n_predict> max tokens to predict
And a completion request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
And the completion is <truncated> truncated
And <n_prompt> prompt tokens are processed
And prometheus metrics are exposed
And metric llamacpp:tokens_predicted is <n_predicted>

Examples: Prompts
| prompt | n_predict | re_content | n_predicted |
| I believe the meaning of life is | 8 | (read\|going)+ | 8 |
| Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 |
| prompt | n_predict | re_content | n_prompt | n_predicted | truncated |
| I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not |
| Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids)+ | 46 | 64 | not |

Scenario: Completion prompt truncated
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns
And the completion is truncated
And 109 prompt tokens are processed


Scenario Outline: OAI Compatibility
Given a model <model>

@@ -44,11 +60,14 @@ Feature: llama.cpp server
And streaming is <enable_streaming>
Given an OAI compatible chat completions request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
And <n_prompt> prompt tokens are processed
And the completion is <truncated> truncated

Examples: Prompts
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
| llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled |
| model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated |
| llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird)+ | -1 | 64 | enabled | |

Scenario: Tokenize / Detokenize
When tokenizing:

@@ -196,12 +196,30 @@ async def step_request_completion(context, api_error):

@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
    context.completion = context.tasks_result.pop()
    assert_n_tokens_predicted(context.completion, predicted_n, re_content)


@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
    context.completion = context.tasks_result.pop()
    assert_n_tokens_predicted(context.completion, predicted_n)


@step(u'the completion is truncated')
def step_assert_completion_truncated(context):
    step_assert_completion_truncated(context, '')


@step(u'the completion is {truncated} truncated')
def step_assert_completion_truncated(context, truncated):
    truncated = truncated != "not"
    assert context.completion['truncated'] == truncated, f'{context.completion}'


@step(u'{n_prompt:d} prompt tokens are processed')
def step_impl(context, n_prompt):
    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"


@step(u'a user prompt {user_prompt}')

@@ -723,7 +741,8 @@ async def oai_chat_completions(user_prompt,
completion_response = {
    'content': '',
    'timings': {
        'predicted_n': 0
        'predicted_n': 0,
        'prompt_n': 0
    }
}
if async_client:

@@ -764,7 +783,8 @@ async def oai_chat_completions(user_prompt,
completion_response = {
    'content': chat_completion_raw['choices'][0]['message'],
    'timings': {
        'predicted_n': chat_completion_raw['usage']['completion_tokens']
        'predicted_n': chat_completion_raw['usage']['completion_tokens'],
        'prompt_n': chat_completion_raw['usage']['prompt_tokens']
    }
}
else:

@@ -793,13 +813,16 @@ async def oai_chat_completions(user_prompt,
if 'content' in delta:
    completion_response['content'] += delta['content']
    completion_response['timings']['predicted_n'] += 1
completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
else:
    assert len(chat_completion.choices) == 1
    completion_response = {
        'content': chat_completion.choices[0].message.content,
        'timings': {
            'predicted_n': chat_completion.usage.completion_tokens
        }
            'predicted_n': chat_completion.usage.completion_tokens,
            'prompt_n': chat_completion.usage.prompt_tokens
        },
        'truncated': chat_completion.choices[0].finish_reason != 'stop'
    }
if debug:
    print("OAI response formatted to llama.cpp:", completion_response)

ggml.c (379 changes)

@@ -1841,6 +1841,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"FLASH_ATTN",
"FLASH_FF",
"FLASH_ATTN_BACK",
"SSM_CONV",
"SSM_SCAN",
"WIN_PART",
"WIN_UNPART",
"GET_REL_POS",

@@ -1863,7 +1865,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",

@@ -1929,6 +1931,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_attn(x)",
"flash_ff(x)",
"flash_attn_back(x)",
"ssm_conv(x)",
"ssm_scan(x)",
"win_part(x)",
"win_unpart(x)",
"get_rel_pos(x)",

@@ -1951,7 +1955,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6154,6 +6158,108 @@ struct ggml_tensor * ggml_flash_attn_back(
return result;
}

// ggml_ssm_conv

struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx,
struct ggml_tensor * s,
struct ggml_tensor * x,
struct ggml_tensor * c,
struct ggml_tensor * sq) {
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(ggml_is_matrix(x));
GGML_ASSERT(ggml_is_matrix(c));
GGML_ASSERT(ggml_is_matrix(sq));
GGML_ASSERT(sq->type == GGML_TYPE_I32);

const int64_t d_conv = c->ne[0];
const int64_t d_inner = c->ne[1];
const int64_t n_tokens = x->ne[1];
const int64_t n_kv = s->ne[2];

GGML_ASSERT( s->ne[0] == d_conv - 1);
GGML_ASSERT( s->ne[1] == d_inner);
GGML_ASSERT( x->ne[0] == d_inner);
GGML_ASSERT(sq->ne[0] == n_kv);
GGML_ASSERT(sq->ne[1] == n_tokens);

bool is_node = false;

if (s->grad || x->grad || c->grad || sq->grad) {
GGML_ASSERT(false); // TODO: implement
is_node = true;
}

// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));

result->op = GGML_OP_SSM_CONV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = s;
result->src[1] = x;
result->src[2] = c;
result->src[3] = sq;

return result;
}

// ggml_ssm_scan

struct ggml_tensor * ggml_ssm_scan(
struct ggml_context * ctx,
struct ggml_tensor * s,
struct ggml_tensor * x,
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C,
struct ggml_tensor * sq) {
GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(sq->type == GGML_TYPE_I32);
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt));

{
const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1];
const int64_t n_tokens = x->ne[1];

GGML_ASSERT(x->ne[0] == d_inner);
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == d_inner);
GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_tokens);
GGML_ASSERT(C->ne[0] == d_state);
GGML_ASSERT(C->ne[1] == n_tokens);
}

bool is_node = false;

if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
GGML_ASSERT(false); // TODO: implement
is_node = true;
}

// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));

result->op = GGML_OP_SSM_SCAN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = s;
result->src[1] = x;
result->src[2] = dt;
result->src[3] = A;
result->src[4] = B;
result->src[5] = C;
result->src[6] = sq;

return result;
}

// ggml_win_part

struct ggml_tensor * ggml_win_part(

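The new `ggml_ssm_conv` operator packs the per-token convolution outputs and the updated rolling conv states into a single flat result tensor. For a single sequence, the per-token update it performs corresponds roughly to the sketch below (a simplified NumPy reference of my own, not the actual kernel; it ignores the multi-sequence state copies the real implementation handles):

```python
import numpy as np

def ssm_conv_step(conv_state, x_t, conv_weight):
    # conv_state: (d_inner, d_conv - 1), x_t: (d_inner,), conv_weight: (d_inner, d_conv)
    window = np.concatenate([conv_state, x_t[:, None]], axis=1)  # shift the state left, insert x on the last column
    y_t = (window * conv_weight).sum(axis=1)                     # rowwise dot product with conv1d.weight
    return y_t, window[:, 1:]                                    # the last (d_conv - 1) columns become the new state
```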
@@ -14771,6 +14877,257 @@ static void ggml_compute_forward_flash_attn_back(
}
}

// ggml_compute_forward_ssm_conv

static void ggml_compute_forward_ssm_conv_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}

const struct ggml_tensor * src0 = dst->src[0]; // conv_state
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
const struct ggml_tensor * src3 = dst->src[3]; // state_seq

const int ith = params->ith;
const int nth = params->nth;

const int nc = src2->ne[0]; // d_conv
const int nr = src0->ne[1]; // d_inner
const int n_t = src1->ne[1]; // n_tokens
const int n_kv = src0->ne[2]; // max number of sequences in the batch

GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// for use with the destination state offset between sequences
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;

if (n_kv > 1) {
// multiple sequences means it's hard to know when it's the first time a state is read,
// so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
// can't use memcpy because of d_conv vs d_conv - 1
for (int i1 = 0; i1 < ir; ++i1) {
for (int i0 = 0; i0 < nc - 1; ++i0) {
// copy s0 to last (d_conv - 1) columns of s
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
}
}
}
}

for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
float * s0; // {d_conv - 1, d_inner, n_kv}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
int ne0s0;

GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);

// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
ne0s0 = src0->ne[0];
} else {
// the source is the last (d_conv - 1) columns of the destination
s0 = s + 1;
ne0s0 = nc;
}

// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// shift state left
for (int i0 = 0; i0 < nc - 1; ++i0) {
s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
}
// insert x on the last column
s[(nc - 1) + i1*nc] = x0[i1];
}

// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}

// it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
sumf += s[i] * c[i];
}
x[i1] = sumf;
}
}
}

static void ggml_compute_forward_ssm_conv(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_ssm_conv_f32(params, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}

// ggml_compute_forward_ssm_scan

static void ggml_compute_forward_ssm_scan_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}

const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
const struct ggml_tensor * src6 = dst->src[6]; // sq

const int ith = params->ith;
const int nth = params->nth;

const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch

GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C, and when copying the states
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[2])
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;

if (n_kv > 1) {
// it's hard to know if the source states have already been copied
// when there are multiple, so copy them already.
for (int i3 = 0; i3 < n_kv; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
memcpy(s, s0, nc*ir*sizeof(float));
}
}

for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
float * s0;
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}

GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);

// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
} else {
// otherwise the source is the same as the destination
s0 = s;
}

// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
}
y[i1] = sumf;
}

// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
}
}

static void ggml_compute_forward_ssm_scan(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_ssm_scan_f32(params, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}

// ggml_compute_forward_win_part

static void ggml_compute_forward_win_part_f32(

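For reference, the inner loop of `ggml_compute_forward_ssm_scan_f32` above is the Mamba selective-state update. A simplified single-token NumPy sketch of the same recurrence (my own reference, not code from this commit; it omits the guard that skips softplus for large `dt` values):

```python
import numpy as np

def ssm_scan_step(ssm_state, x_t, dt_t, A, B_t, C_t):
    # ssm_state, A: (d_inner, d_state); x_t, dt_t: (d_inner,); B_t, C_t: (d_state,)
    dt = np.log1p(np.exp(dt_t))               # softplus(dt), as in the kernel
    dA = np.exp(dt[:, None] * A)              # per-channel state decay
    dBx = (dt * x_t)[:, None] * B_t[None, :]  # input contribution, dB * x
    new_state = ssm_state * dA + dBx          # state = prev_state * dA + dB * x
    y_t = new_state @ C_t                     # y = rowwise_dotprod(state, C)
    return y_t, new_state
```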
@@ -15830,6 +16187,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                bool masked = t != 0;
                ggml_compute_forward_flash_attn_back(params, masked, tensor);
            } break;
        case GGML_OP_SSM_CONV:
            {
                ggml_compute_forward_ssm_conv(params, tensor);
            } break;
        case GGML_OP_SSM_SCAN:
            {
                ggml_compute_forward_ssm_scan(params, tensor);
            } break;
        case GGML_OP_WIN_PART:
            {
                ggml_compute_forward_win_part(params, tensor);
@@ -16884,6 +17249,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                GGML_ASSERT(false); // not supported
            } break;
        case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_UNARY:
@@ -17590,6 +17960,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_SSM_CONV:
        case GGML_OP_SSM_SCAN:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
19
ggml.h
@@ -472,6 +472,8 @@ extern "C" {
        GGML_OP_FLASH_ATTN,
        GGML_OP_FLASH_FF,
        GGML_OP_FLASH_ATTN_BACK,
        GGML_OP_SSM_CONV,
        GGML_OP_SSM_SCAN,
        GGML_OP_WIN_PART,
        GGML_OP_WIN_UNPART,
        GGML_OP_GET_REL_POS,

@@ -1728,6 +1730,23 @@ extern "C" {
            struct ggml_tensor  * c0,
            struct ggml_tensor  * c1);

    GGML_API struct ggml_tensor * ggml_ssm_conv(
            struct ggml_context * ctx,
            struct ggml_tensor  * s,
            struct ggml_tensor  * x,
            struct ggml_tensor  * c,
            struct ggml_tensor  * sq);

    GGML_API struct ggml_tensor * ggml_ssm_scan(
            struct ggml_context * ctx,
            struct ggml_tensor  * s,
            struct ggml_tensor  * x,
            struct ggml_tensor  * dt,
            struct ggml_tensor  * A,
            struct ggml_tensor  * B,
            struct ggml_tensor  * C,
            struct ggml_tensor  * sq);

    // partition into non-overlapping windows with padding if needed
    // example:
    // a: 768 64 64 1
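For reference (notation ours, summarized from the f32 kernel earlier in this diff, not from the header): ggml_ssm_scan takes the previous states s, the inputs x, the time steps dt, and the SSM parameters A, B, C, with sq carrying the sequence ids that select which cached state slot each token reads and writes. Per inner channel $i$ and state dimension $j$ the update is

\Delta_{t,i} = \operatorname{softplus}(dt_{t,i}), \qquad
s_{t,i,j} = s_{t-1,i,j}\, e^{\Delta_{t,i} A_{i,j}} + \Delta_{t,i}\, x_{t,i}\, B_{t,j}, \qquad
y_{t,i} = \sum_{j} s_{t,i,j}\, C_{t,j}.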
gguf-py/gguf/constants.py

@@ -61,6 +61,12 @@ class Keys:
        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
        SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"

    class SSM:
        CONV_KERNEL = "{arch}.ssm.conv_kernel"
        INNER_SIZE = "{arch}.ssm.inner_size"
        STATE_SIZE = "{arch}.ssm.state_size"
        TIME_STEP_RANK = "{arch}.ssm.time_step_rank"

    class Tokenizer:
        MODEL = "tokenizer.ggml.model"
        LIST = "tokenizer.ggml.tokens"

@@ -113,6 +119,7 @@ class MODEL_ARCH(IntEnum):
    MINICPM = auto()
    GEMMA = auto()
    STARCODER2 = auto()
    MAMBA = auto()


class MODEL_TENSOR(IntEnum):

@@ -144,6 +151,13 @@ class MODEL_TENSOR(IntEnum):
    ATTN_Q_NORM = auto()
    ATTN_K_NORM = auto()
    LAYER_OUT_NORM = auto()
    SSM_IN = auto()
    SSM_CONV1D = auto()
    SSM_X = auto()
    SSM_DT = auto()
    SSM_A = auto()
    SSM_D = auto()
    SSM_OUT = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {

@@ -171,6 +185,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.MINICPM: "minicpm",
    MODEL_ARCH.GEMMA: "gemma",
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.MAMBA: "mamba",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {

@@ -202,6 +217,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
    MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
    MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
    MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
    MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
    MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
    MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
    MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -543,6 +565,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.MAMBA: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.SSM_IN,
        MODEL_TENSOR.SSM_CONV1D,
        MODEL_TENSOR.SSM_X,
        MODEL_TENSOR.SSM_DT,
        MODEL_TENSOR.SSM_A,
        MODEL_TENSOR.SSM_D,
        MODEL_TENSOR.SSM_OUT,
    ],
    # TODO
}

@@ -734,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED

# SSM
KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
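As a quick illustration (not part of the diff): these keys are templated on the architecture name, so for MODEL_ARCH.MAMBA they expand as below. This assumes the gguf package re-exports the constants the same way the convert script uses them.

import gguf

arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.MAMBA]       # "mamba"
print(gguf.Keys.SSM.CONV_KERNEL.format(arch=arch))        # mamba.ssm.conv_kernel
print(gguf.Keys.SSM.TIME_STEP_RANK.format(arch=arch))     # mamba.ssm.time_step_rank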
gguf-py/gguf/gguf_writer.py

@@ -382,6 +382,18 @@ class GGUFWriter:
    def add_rope_scaling_finetuned(self, value: bool) -> None:
        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

    def add_ssm_conv_kernel(self, value: int) -> None:
        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

    def add_ssm_inner_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)

    def add_ssm_state_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)

    def add_ssm_time_step_rank(self, value: int) -> None:
        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

    def add_tokenizer_model(self, model: str) -> None:
        self.add_string(Keys.Tokenizer.MODEL, model)
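A minimal sketch of how a converter might record these hyperparameters with the new setters. The output file name and the numeric values (roughly mamba-130m-like) are placeholders, and the write_header_to_file / write_kv_data_to_file / close calls assume the usual GGUFWriter workflow; other metadata and the tensor data are omitted.

import gguf

writer = gguf.GGUFWriter("mamba-example.gguf", gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.MAMBA])
writer.add_ssm_conv_kernel(4)       # d_conv   (placeholder value)
writer.add_ssm_inner_size(1536)     # d_inner  (placeholder value)
writer.add_ssm_state_size(16)       # d_state  (placeholder value)
writer.add_ssm_time_step_rank(48)   # dt_rank  (placeholder value)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()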
gguf-py/gguf/tensor_mapping.py

@@ -20,6 +20,9 @@ class TensorNameMap:
            "wte",                     # gpt2
            "transformer.embd.wte",    # phi2
            "model.tok_embeddings",    # internlm2
            "model.embedding",         # mamba-qbert
            "backbone.embedding",      # mamba
            "backbone.embeddings",     # mamba-hf
        ),

        # Token type embeddings

@@ -44,7 +47,7 @@ class TensorNameMap:
        # Output
        MODEL_TENSOR.OUTPUT: (
            "embed_out",                 # gptneox
            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen
            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba
            "output",                    # llama-pth bloom internlm2
            "word_embeddings_for_head",  # persimmon
            "lm_head.linear",            # phi2

@@ -61,6 +64,8 @@ class TensorNameMap:
            "language_model.encoder.final_layernorm",  # persimmon
            "model.final_layernorm",                   # persimmon
            "lm_head.ln",                              # phi2
            "model.norm_f",                            # mamba-qbert
            "backbone.norm_f",                         # mamba
        ),

        # Rope frequencies

@@ -86,6 +91,8 @@ class TensorNameMap:
            "transformer.h.{bid}.ln",             # phi2
            "model.layers.layers.{bid}.norm",     # plamo
            "model.layers.{bid}.attention_norm",  # internlm2
            "model.layers.{bid}.norm",            # mamba-qbert
            "backbone.layers.{bid}.norm",         # mamba
        ),

        # Attention norm 2

@@ -282,7 +289,42 @@ class TensorNameMap:
        MODEL_TENSOR.LAYER_OUT_NORM: (
            "encoder.layer.{bid}.output.LayerNorm",  # bert
            "encoder.layers.{bid}.norm2",            # nomic-bert
        )
        ),

        MODEL_TENSOR.SSM_IN: (
            "model.layers.{bid}.in_proj",
            "backbone.layers.{bid}.mixer.in_proj",
        ),

        MODEL_TENSOR.SSM_CONV1D: (
            "model.layers.{bid}.conv1d",
            "backbone.layers.{bid}.mixer.conv1d",
        ),

        MODEL_TENSOR.SSM_X: (
            "model.layers.{bid}.x_proj",
            "backbone.layers.{bid}.mixer.x_proj",
        ),

        MODEL_TENSOR.SSM_DT: (
            "model.layers.{bid}.dt_proj",
            "backbone.layers.{bid}.mixer.dt_proj",
        ),

        MODEL_TENSOR.SSM_A: (
            "model.layers.{bid}.A_log",
            "backbone.layers.{bid}.mixer.A_log",
        ),

        MODEL_TENSOR.SSM_D: (
            "model.layers.{bid}.D",
            "backbone.layers.{bid}.mixer.D",
        ),

        MODEL_TENSOR.SSM_OUT: (
            "model.layers.{bid}.out_proj",
            "backbone.layers.{bid}.mixer.out_proj",
        ),
    }

    mapping: dict[str, tuple[MODEL_TENSOR, str]]
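With these entries in place, a converter can map original Mamba checkpoint names to GGUF names through the existing TensorNameMap machinery. A small sketch, assuming gguf.get_tensor_name_map and TensorNameMap.get_name behave as they do in the convert scripts; the block count of 24 is illustrative only.

import gguf

name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MAMBA, 24)
new_name = name_map.get_name("backbone.layers.0.mixer.A_log", try_suffixes=(".weight", ".bias"))
print(new_name)  # blk.0.ssm_a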
4
llama.h
@@ -235,6 +235,7 @@ extern "C" {
        uint32_t seed;             // RNG seed, -1 for random
        uint32_t n_ctx;            // text context, 0 = from model
        uint32_t n_batch;          // prompt processing maximum batch size
        uint32_t n_parallel;       // number of parallel sequences (i.e. distinct states for recurrent models)
        uint32_t n_threads;        // number of threads to use for generation
        uint32_t n_threads_batch;  // number of threads to use for batch processing

@@ -376,6 +377,7 @@ extern "C" {

    LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
    LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
    LLAMA_API uint32_t llama_n_max_seq  (const struct llama_context * ctx);

    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
    LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);

@@ -502,7 +504,7 @@ extern "C" {
    // seq_id < 0 : match any sequence
    // p0 < 0     : [0,  p1]
    // p1 < 0     : [p0, inf)
    LLAMA_API void llama_kv_cache_seq_rm(
    LLAMA_API bool llama_kv_cache_seq_rm(
            struct llama_context * ctx,
            llama_seq_id seq_id,
            llama_pos p0,