From 90568a669644535316491e5cd073f0966677f1c0 Mon Sep 17 00:00:00 2001 From: mike dupont Date: Sat, 25 Nov 2023 11:13:45 -0500 Subject: [PATCH] now server has it --- binding.py | 21 ++++--- examples/server/server.cpp | 120 ++++++++++++++++++++++++++++++++++++- 2 files changed, 131 insertions(+), 10 deletions(-) diff --git a/binding.py b/binding.py index 668afd566..217dce684 100644 --- a/binding.py +++ b/binding.py @@ -14,8 +14,9 @@ llvmLibPath = "/usr/lib/llvm-15/lib/" cxxClientRoot = "/home/mdupont/experiments/llama.cpp/" fileList = [ - "ggml.cpp", - "llama.cpp" +# "ggml.cpp", +# "llama.cpp", + "examples/server/server.cpp", ] typeList = [ @@ -224,10 +225,11 @@ UNNAMED_STRUCT_DELIM = '::(unnamed struct' def traverse(node, namespace, main_file): # only scan the elements of the file we parsed - #print("FILE", node.location.file ) + if node.kind == clang.cindex.CursorKind.STRUCT_DECL or node.kind == clang.cindex.CursorKind.CLASS_DECL: fullStructName = "::".join([*namespace, node.displayname]) + print("#FILE", node.location.file ) print("REFL_TYPE(" + fullStructName + ")") structFields = [] @@ -247,14 +249,15 @@ def traverse(node, namespace, main_file): "type": struct_type, }) # replica read changes introduced duplicate get requests - if any(map(lambda op: op['name'] == fullStructName, opTypes)): - return + #if any(map(lambda op: op['name'] == fullStructName, opTypes)): + # return - opTypes.append({ - "name": fullStructName, - "fields": structFields, - }) + #opTypes.append({ + # "name": fullStructName, + # "fields": structFields, + #}) print("REFL_END") + if node.kind == clang.cindex.CursorKind.TYPE_ALIAS_DECL: fullStructName = "::".join([*namespace, node.displayname]) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 50f124b13..a42bba9b6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -24,6 +24,7 @@ #include #include #include +#include "print.hpp" #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 @@ -33,6 +34,9 @@ using json = nlohmann::json; +REFL_TYPE(std::less< ::nlohmann::detail::value_t>) +REFL_END + struct server_params { std::string hostname = "127.0.0.1"; @@ -41,6 +45,13 @@ struct server_params int32_t read_timeout = 600; int32_t write_timeout = 600; }; +REFL_TYPE(server_params) + REFL_FIELD(hostname) + REFL_FIELD(public_path) + REFL_FIELD(port) + REFL_FIELD(read_timeout) + REFL_FIELD(write_timeout) +REFL_END static bool server_verbose = false; @@ -157,6 +168,15 @@ struct task_server { bool embedding_mode = false; }; +REFL_TYPE(task_server) + REFL_FIELD(id) + REFL_FIELD(target_id) + REFL_FIELD(type) + REFL_FIELD(data) + REFL_FIELD(infill_mode) + REFL_FIELD(embedding_mode) +REFL_END + struct task_result { int id; bool stop; @@ -193,6 +213,18 @@ struct slot_params json input_suffix; }; +REFL_TYPE(slot_params) + REFL_FIELD(stream) + REFL_FIELD(cache_prompt) + REFL_FIELD(seed) + REFL_FIELD(n_keep) + REFL_FIELD(n_predict) + REFL_FIELD(antiprompt) + REFL_FIELD(input_prefix) + REFL_FIELD(input_suffix) +REFL_END + + struct slot_image { int32_t id; @@ -220,6 +252,17 @@ struct completion_token_output std::string text_to_send; }; +REFL_TYPE(completion_token_output) + REFL_FIELD(probs) + REFL_FIELD(tok) + REFL_FIELD(text_to_send) +REFL_END + +REFL_TYPE(completion_token_output::token_prob) + REFL_FIELD(tok) + REFL_FIELD(prob) +REFL_END + static size_t common_part(const std::vector &a, const std::vector &b) { size_t i; @@ -496,6 +539,51 @@ struct llama_client_slot } }; +//REFL_TYPE(llama_client_slot::llama_sampling_params) +//REFL_END + +REFL_TYPE(llama_client_slot) + REFL_FIELD(id) + REFL_FIELD(task_id) + REFL_FIELD(params) + REFL_FIELD(state) + REFL_FIELD(command) + REFL_FIELD(t_last_used) + REFL_FIELD(n_ctx) + REFL_FIELD(n_past) + REFL_FIELD(n_decoded) + REFL_FIELD(n_remaining) + REFL_FIELD(i_batch) + REFL_FIELD(num_prompt_tokens) + REFL_FIELD(num_prompt_tokens_processed) + REFL_FIELD(multibyte_pending) + REFL_FIELD(prompt) + REFL_FIELD(generated_text) + REFL_FIELD(sampled) + REFL_FIELD(cache_tokens) + REFL_FIELD(generated_token_probs) + REFL_FIELD(infill) + REFL_FIELD(embedding) + REFL_FIELD(has_next_token) + REFL_FIELD(truncated) + REFL_FIELD(stopped_eos) + REFL_FIELD(stopped_word) + REFL_FIELD(stopped_limit) + REFL_FIELD(oaicompat) + REFL_FIELD(oaicompat_model) + REFL_FIELD(stopping_word) + REFL_FIELD(sparams) + REFL_FIELD(ctx_sampling) + REFL_FIELD(images) + REFL_FIELD(sent_count) + REFL_FIELD(sent_token_probs_index) + REFL_FIELD(t_start_process_prompt) + REFL_FIELD(t_start_genereration) + REFL_FIELD(t_prompt_processing) + REFL_FIELD(t_token_generation) +REFL_END + + struct llama_server_context { llama_model *model = nullptr; @@ -878,7 +966,7 @@ struct llama_server_context all_slots_are_idle = false; LOG_TEE("slot %i is processing [task id: %i]\n", slot->id, slot->task_id); - + print_fields(*slot); return true; } @@ -1787,6 +1875,31 @@ struct llama_server_context } }; +REFL_TYPE(llama_server_context) + REFL_FIELD(model) + REFL_FIELD(ctx) + REFL_FIELD(clp_ctx) + REFL_FIELD(params) + REFL_FIELD(batch) + REFL_FIELD(multimodal) + REFL_FIELD(clean_kv_cache) + REFL_FIELD(all_slots_are_idle) + REFL_FIELD(add_bos_token) + REFL_FIELD(id_gen) + REFL_FIELD(n_ctx) + REFL_FIELD(system_need_update) + REFL_FIELD(system_prompt) + REFL_FIELD(system_tokens) + REFL_FIELD(name_user) + REFL_FIELD(name_assistant) + REFL_FIELD(slots) + REFL_FIELD(queue_tasks) + REFL_FIELD(queue_results) + REFL_FIELD(mutex_tasks) + REFL_FIELD(mutex_results) +REFL_END + + static void server_print_usage(const char *argv0, const gpt_params ¶ms, const server_params &sparams) { @@ -2497,6 +2610,11 @@ struct token_translator std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); } }; + +REFL_TYPE(token_translator) + REFL_FIELD(ctx) +REFL_END + static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, llama_client_slot *slot) { auto & gtps = slot->generated_token_probs;