diff --git a/Makefile b/Makefile
index 1623bbba0..f7cf21d5c 100644
--- a/Makefile
+++ b/Makefile
@@ -357,7 +357,7 @@ expose.o: expose.cpp expose.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 # idiotic "for easier compilation"
-GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp llama.cpp otherarch/utils.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml.h ggml-cuda.h llama.h otherarch/llama-util.h
+GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp llama.cpp otherarch/utils.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml.h ggml-cuda.h llama.h otherarch/llama-util.h
 gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
 	$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) -c $< -o $@
 gpttype_adapter.o: $(GPTTYPE_ADAPTER)
diff --git a/expose.cpp b/expose.cpp
index febfcba45..d515f24e1 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -240,4 +240,10 @@ extern "C"
     bool abort_generate()
     {
         return gpttype_generate_abort();
     }
+
+    int token_count(const char * input)
+    {
+        std::string inputstr = input;
+        return gpttype_token_count(inputstr);
+    }
 }
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 9eb7ede2c..35c87d79b 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -338,6 +338,36 @@ static std::string FileFormatTokenizeID(int id, FileFormat file_format)
     }
 }
 
+static std::vector<int> TokenizeString(const std::string & str_to_tokenize, FileFormat file_format)
+{
+    std::vector<int> tokvec;
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA)
+    {
+        if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
+        {
+            tokvec = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
+        }
+        else if (file_format == FileFormat::GGML)
+        {
+            tokvec = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
+        }
+        else if (file_format == FileFormat::GGJT_3)
+        {
+            tokvec = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
+        }
+        else
+        {
+            tokvec = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true);
+        }
+    }
+    else
+    {
+        // tokenize the prompt
+        tokvec = ::gpt_tokenize(vocab, str_to_tokenize);
+    }
+    return tokvec;
+}
+
 static std::string RemoveBell(const std::string & input) //removes the bell character
 {
     std::string word2;
@@ -965,6 +995,21 @@ bool gpttype_generate_abort()
     return true;
 }
 
+int gpttype_token_count(const std::string & input)
+{
+    if(debugmode==1)
+    {
+        printf("\nFileFormat: %d, Tokenizing: %s",file_format ,input.c_str());
+    }
+    auto toks = TokenizeString(input, file_format);
+    int tokcount = toks.size();
+    if(debugmode==1)
+    {
+        printf("\nTokens Counted: %d\n",tokcount);
+    }
+    return tokcount;
+}
+
 const std::string & gpttype_get_pending_output()
 {
     return concat_output;
@@ -1018,32 +1063,7 @@
     }
 
     // tokenize the prompt
-    std::vector<int> embd_inp;
-
-    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA)
-    {
-        if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
-        {
-            embd_inp = ::llama_v2_tokenize(llama_ctx_v2, params.prompt, true);
-        }
-        else if (file_format == FileFormat::GGML)
-        {
-            embd_inp = ::legacy_llama_v2_tokenize(llama_ctx_v2, params.prompt, true);
-        }
-        else if (file_format == FileFormat::GGJT_3)
-        {
-            embd_inp = ::llama_v3_tokenize(llama_ctx_v3, params.prompt, true);
-        }
-        else
-        {
-            embd_inp = ::llama_tokenize(llama_ctx_v4, params.prompt, true);
-        }
-    }
-    else
-    {
-        // tokenize the prompt
-        embd_inp = ::gpt_tokenize(vocab, params.prompt);
-    }
+    std::vector<int> embd_inp = TokenizeString(params.prompt, file_format);
 
     //truncate to front of the prompt if its too long
     int32_t nctx = params.n_ctx;
diff --git a/koboldcpp.py b/koboldcpp.py
index a5ada2677..43a0590a2 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -178,6 +178,7 @@ def init_library():
     handle.get_last_token_count.restype = ctypes.c_int
     handle.get_last_stop_reason.restype = ctypes.c_int
     handle.abort_generate.restype = ctypes.c_bool
+    handle.token_count.restype = ctypes.c_int
     handle.get_pending_output.restype = ctypes.c_char_p
 
 def load_model(model_filename):
@@ -528,6 +529,22 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         kai_sse_stream_flag = False
         self.path = self.path.rstrip('/')
 
+        if self.path.endswith(('/api/extra/tokencount')):
+            try:
+                genparams = json.loads(body)
+                countprompt = genparams.get('prompt', "")
+                count = handle.token_count(countprompt.encode("UTF-8"))
+                self.send_response(200)
+                self.end_headers()
+                self.wfile.write(json.dumps({"value": count}).encode())
+
+            except ValueError as e:
+                utfprint("Count Tokens - Body Error: " + str(e))
+                self.send_response(400)
+                self.end_headers()
+                self.wfile.write(json.dumps({"value": -1}).encode())
+            return
+
         if self.path.endswith('/api/extra/abort'):
             ag = handle.abort_generate()
             self.send_response(200)
@@ -831,7 +848,7 @@ def show_new_gui():
 
     debugmode = ctk.IntVar()
     lowvram_var = ctk.IntVar()
-    mmq_var = ctk.IntVar()
+    mmq_var = ctk.IntVar(value=1)
 
     blas_threads_var = ctk.StringVar()
     blas_size_var = ctk.IntVar()
diff --git a/model_adapter.h b/model_adapter.h
index f4e8a7034..4a4f9a540 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -62,6 +62,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
 bool gpttype_generate_abort();
 const std::string & gpttype_get_pending_output();
+int gpttype_token_count(const std::string & input);
 
 void timer_start();
 double timer_check();
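
Note: a quick way to exercise the new endpoint end-to-end. The path
/api/extra/tokencount and the {"prompt": ...} request / {"value": ...}
response shapes come straight from the handler above; the host, port, and
sample prompt are assumptions (match the port to your --port setting).
A minimal sketch using only the Python standard library:

    import json
    import urllib.request

    # POST a prompt to the tokencount endpoint; the server tokenizes it with
    # the loaded model's own tokenizer (via handle.token_count) and replies
    # with {"value": <count>}. A malformed JSON body yields HTTP 400 and
    # {"value": -1}, which urlopen surfaces as an HTTPError.
    req = urllib.request.Request(
        "http://localhost:5001/api/extra/tokencount",  # port is an assumption
        data=json.dumps({"prompt": "Hello, world!"}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp)["value"])  # token count; varies by model

The same count is reachable in-process through the ctypes binding registered
in init_library (handle.token_count takes a UTF-8 encoded byte string), which
is exactly what the HTTP handler does internally.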