From 8c453d1e4eca18ad2338c5d8bb9cc7b206ca6835 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Mon, 18 Sep 2023 23:02:00 +0800 Subject: [PATCH] added grammar sampling --- CMakeLists.txt | 4 +- Makefile | 22 +-- build-info.h | 2 + expose.h | 1 + gpttype_adapter.cpp | 455 ++++++++++++++++++++++++++------------------ koboldcpp.py | 12 +- 6 files changed, 291 insertions(+), 205 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c7a9638d1..6206737c2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -380,7 +380,9 @@ set_target_properties(ggml_v2 PROPERTIES POSITION_INDEPENDENT_CODE ON) add_library(common2 common/common.cpp - common/common.h) + common/common.h + common/grammar-parser.h + common/grammar-parser.cpp) target_include_directories(common2 PUBLIC . ./otherarch ./otherarch/tools ./examples ./common) target_compile_features(common2 PUBLIC cxx_std_11) # don't bump target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS}) diff --git a/Makefile b/Makefile index cadbeedb1..440ff6112 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -default: koboldcpp_default koboldcpp_failsafe koboldcpp_openblas koboldcpp_noavx2 koboldcpp_clblast koboldcpp_cublas koboldcpp_hipblas +default: koboldcpp_default koboldcpp_failsafe koboldcpp_openblas koboldcpp_noavx2 koboldcpp_clblast koboldcpp_cublas koboldcpp_hipblas tools: quantize_gpt2 quantize_gptj quantize_llama quantize_neox quantize_mpt dev: koboldcpp_openblas dev2: koboldcpp_clblast @@ -211,15 +211,15 @@ endif # LLAMA_CUDA_FORCE_DMMV ggml-cuda.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \ -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \ -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) \ - -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) + -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) ggml_v2-cuda.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \ -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \ -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) \ - -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) + -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) ggml_v2-cuda-legacy.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \ -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \ -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) \ - -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) + -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(CXX) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h @@ -426,19 +426,19 @@ gguf: examples/gguf/gguf.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) #generated libraries -koboldcpp_default: ggml.o ggml_v2.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o $(OBJS) +koboldcpp_default: ggml.o ggml_v2.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o grammar-parser.o $(OBJS) $(DEFAULT_BUILD) -koboldcpp_openblas: ggml_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o $(OBJS) +koboldcpp_openblas: ggml_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o grammar-parser.o $(OBJS) $(OPENBLAS_BUILD) -koboldcpp_failsafe: ggml_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_failsafe.o ggml-alloc.o $(OBJS) +koboldcpp_failsafe: ggml_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_failsafe.o 
ggml-alloc.o grammar-parser.o $(OBJS) $(FAILSAFE_BUILD) -koboldcpp_noavx2: ggml_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_noavx2.o ggml-alloc.o $(OBJS) +koboldcpp_noavx2: ggml_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_noavx2.o ggml-alloc.o grammar-parser.o $(OBJS) $(NOAVX2_BUILD) -koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o ggml-alloc.o $(OBJS) +koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o ggml-alloc.o grammar-parser.o $(OBJS) $(CLBLAST_BUILD) -koboldcpp_cublas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o $(CUBLAS_OBJS) $(OBJS) +koboldcpp_cublas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o grammar-parser.o $(CUBLAS_OBJS) $(OBJS) $(CUBLAS_BUILD) -koboldcpp_hipblas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o $(HIP_OBJS) $(OBJS) +koboldcpp_hipblas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o grammar-parser.o $(HIP_OBJS) $(OBJS) $(HIPBLAS_BUILD) quantize_llama: examples/quantize/quantize.cpp ggml.o llama.o k_quants.o ggml-alloc.o diff --git a/build-info.h b/build-info.h index 70a4db739..6428ea975 100644 --- a/build-info.h +++ b/build-info.h @@ -3,5 +3,7 @@ #define BUILD_NUMBER 999 #define BUILD_COMMIT "KOBOLDCPP" +#define BUILD_COMPILER "KCPP" +#define BUILD_TARGET "KCPP" #endif // BUILD_INFO_H diff --git a/expose.h b/expose.h index a6768980d..535e11374 100644 --- a/expose.h +++ b/expose.h @@ -72,6 +72,7 @@ struct generation_inputs const bool unban_tokens_rt; const char * stop_sequence[stop_token_max]; const bool stream_sse; + const char * grammar; }; struct generation_outputs { diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 69603004f..321000d4e 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -11,6 +11,7 @@ #include #include "model_adapter.h" #include "otherarch.h" +#include "grammar-parser.h" //for easier compilation //concat source files into one file for compilation purposes @@ -41,10 +42,14 @@ int last_token_count = 0; stop_reason last_stop_reason = stop_reason::INVALID; std::vector generated_tokens; +llama_grammar * grammar = nullptr; //currently used grammar +grammar_parser::parse_state parsed_grammar; + //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt) static FileFormat file_format = FileFormat::BADFORMAT; static gpt_vocab vocab; +static int32_t n_vocab = 0; static gptj_v1_model gptj_ctx_v1; static gptj_v2_model gptj_ctx_v2; @@ -61,6 +66,7 @@ static mpt_model mpt_ctx_v3; static rwkv_v2_context * rwkv_ctx_v2; static rwkv_context * rwkv_ctx_v3; + static llama_v2_context * llama_ctx_v2; static llama_v3_context * llama_ctx_v3; static llama_context * llama_ctx_v4; @@ -115,6 +121,133 @@ inline bool LogitsDuplicated(std::vector & arr1, std::vector & arr } +static std::string FileFormatTokenizeID(int id, FileFormat file_format) +{ + if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2) + { + return std::string(llama_v2_token_to_str(llama_ctx_v2, 
id)); + } + else if (file_format == FileFormat::GGJT_3) + { + return std::string(llama_v3_token_to_str(llama_ctx_v3, id)); + } + else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) + { + return std::string(llama_token_to_str(llama_ctx_v4, id)); + } + else + { + return vocab.id_to_token[id]; + } +} + +static void TokenizeString(const std::string & str_to_tokenize, std::vector & output_tokens, FileFormat file_format) +{ + if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) + { + if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 ) + { + output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true); + } + else if (file_format == FileFormat::GGML) + { + output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true); + } + else if (file_format == FileFormat::GGJT_3) + { + output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true); + } + else + { + output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true); + } + } + else + { + // tokenize the prompt + output_tokens = ::gpt_tokenize(vocab, str_to_tokenize); + } +} +static int GetEosID(FileFormat file_format, int32_t n_vocab) +{ + unsigned int eosID = 0; + + if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) + { + if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) + { + eosID = llama_token_eos(llama_ctx_v4); + } + else if(file_format == FileFormat::GGJT_3) + { + eosID = llama_v3_token_eos(); + } + else + { + eosID = llama_v3_token_eos(); + } + } + else + { + if (file_format == FileFormat::GPT2_1 || + file_format == FileFormat::GPT2_2 || + file_format == FileFormat::GPT2_3 || + file_format == FileFormat::GPT2_4 || + file_format == FileFormat::GPTJ_1 || + file_format == FileFormat::GPTJ_2 || + file_format == FileFormat::GPTJ_3 || + file_format == FileFormat::GPTJ_4 || + file_format == FileFormat::GPTJ_5) + { + eosID = 50256; + if (n_vocab <= eosID) + { + //special case, starcoder models use ID 0 for EOS + eosID = 0; + } + } + + if (file_format == FileFormat::RWKV_1 || + file_format == FileFormat::RWKV_2 || + file_format == FileFormat::NEOX_1 || + file_format == FileFormat::NEOX_2 || + file_format == FileFormat::NEOX_3 || + file_format == FileFormat::NEOX_4 || + file_format == FileFormat::NEOX_5 || + file_format == FileFormat::NEOX_6 || + file_format == FileFormat::NEOX_7 || + file_format == FileFormat::MPT_1) + { + eosID = 0; + } + } + return eosID; +} +static float LowestLogit(const std::vector & logits) +{ + int topid = std::min_element(logits.begin(), logits.end()) - logits.begin(); + float v = logits[topid]; + return (v < 0 ? (v-8) : 0); +} +static float LowestLogit(const float *logits, size_t size) +{ + if (size == 0) { + // Handle the case of an empty array + return 0.0; + } + int topid = std::min_element(logits, logits + size) - logits; + float v = logits[topid]; + return (v < 0 ? 
(v-8) : 0); +} + +static std::string RemoveBell(const std::string & input) //removes the bell character +{ + std::string word2; + std::remove_copy(input.begin(), input.end(), std::back_inserter(word2), '\a'); + return word2; +} + + llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng) { llama_sample_softmax(nullptr, candidates); @@ -256,8 +389,47 @@ void sample_temperature(llama_token_data_array * candidates_p, float temp) } } +void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + + const int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } + + const llama_token eos = GetEosID(file_format,n_vocab); + + std::vector, llama_partial_utf8>> candidates_decoded; + std::vector candidates_grammar; + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const std::string piece = FileFormatTokenizeID(id,file_format); + if (id == eos) { + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; + } + } else if (piece.empty() || piece[0] == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (const auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + +} + int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng, -int mirostat, float mirostat_tau, float mirostat_eta, const std::vector & sampler_order) +int mirostat, float mirostat_tau, float mirostat_eta, const std::vector & sampler_order, llama_grammar * grammar) { int id = 0; std::vector candidates; @@ -268,6 +440,10 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vectorstacks) { + if (stack.empty()) { + return; + } + } + GGML_ASSERT(false); } - else if (file_format == FileFormat::GGJT_3) - { - return std::string(llama_v3_token_to_str(llama_ctx_v3, id)); - } - else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) - { - return std::string(llama_token_to_str(llama_ctx_v4, id)); - } - else - { - return vocab.id_to_token[id]; + const std::string piece = FileFormatTokenizeID(token,file_format); //llama_token_to_str(ctx, token); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); + const auto & code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); } + grammar->partial_utf8 = decoded.second; + GGML_ASSERT(!grammar->stacks.empty()); } -static void TokenizeString(const std::string & str_to_tokenize, std::vector & output_tokens, FileFormat file_format) +static void load_grammar(const std::string & gammarstr) { - if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || 
file_format==FileFormat::GGUF_FALCON) + if(grammar!=nullptr) //on demand free when next grammar is loaded { - if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 ) - { - output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true); - } - else if (file_format == FileFormat::GGML) - { - output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true); - } - else if (file_format == FileFormat::GGJT_3) - { - output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true); - } - else - { - output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true); - } + llama_grammar_free(grammar); + grammar = nullptr; } - else - { - // tokenize the prompt - output_tokens = ::gpt_tokenize(vocab, str_to_tokenize); - } -} -static float LowestLogit(const std::vector & logits) -{ - int topid = std::min_element(logits.begin(), logits.end()) - logits.begin(); - float v = logits[topid]; - return (v < 0 ? (v-8) : 0); -} -static float LowestLogit(const float *logits, size_t size) -{ - if (size == 0) { - // Handle the case of an empty array - return 0.0; + if (!gammarstr.empty()) { + parsed_grammar = grammar_parser::parse(gammarstr.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + printf("\nIgnored invalid grammar sampler."); + return; + } + grammar_parser::print_grammar(stderr, parsed_grammar); + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - int topid = std::min_element(logits, logits + size) - logits; - float v = logits[topid]; - return (v < 0 ? (v-8) : 0); -} - -static std::string RemoveBell(const std::string & input) //removes the bell character -{ - std::string word2; - std::remove_copy(input.begin(), input.end(), std::back_inserter(word2), '\a'); - return word2; } ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta) @@ -522,6 +670,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } } + n_vocab = llama_v2_n_vocab(llama_ctx_v2); + //determine mem per token const std::vector tmp = {1, 2, 3, 4}; llama_v2_eval(llama_ctx_v2, tmp.data(), tmp.size(), 0, params.n_threads); @@ -587,6 +737,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } } + n_vocab = llama_v3_n_vocab(llama_ctx_v3); + //determine mem per token const std::vector tmp = {1, 2, 3, 4}; auto er = llama_v3_eval(llama_ctx_v3, tmp.data(), tmp.size(), 0, params.n_threads); @@ -663,6 +815,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } } + n_vocab = llama_n_vocab(llama_ctx_v4); + //determine mem per token const std::vector tmp = {1, 2, 3, 4}; auto er = llama_eval(llama_ctx_v4, tmp.data(), tmp.size(), 0, params.n_threads); @@ -720,6 +874,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nRWKV Vocab: %u\n", vocabsiz); logits.resize(vocabsiz); + n_vocab = vocab.id_to_token.size(); //handled seperately + if (file_format == FileFormat::RWKV_1) { n_batch = 1; @@ -790,6 +946,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nTensor Transposition Detected! 
Retrying GPT-2 model loading..."); return res; } + + n_vocab = gpt2_ctx_v1.hparams.n_vocab; + // determine the required inference memory per token: legacy_gpt2_eval(gpt2_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); return ModelLoadResult::SUCCESS; @@ -809,6 +968,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nTensor Transposition Detected! Retrying GPT-2 model loading..."); return res; } + + n_vocab = gpt2_ctx_v3.hparams.n_vocab; + // determine the required inference memory per token: gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch); return ModelLoadResult::SUCCESS; @@ -829,6 +991,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nTensor Transposition Detected! Retrying GPT-2 model loading..."); return res; } + + n_vocab = gpt2_ctx_v2.hparams.n_vocab; + // determine the required inference memory per token: gpt2_v2_eval(gpt2_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); return ModelLoadResult::SUCCESS; @@ -847,6 +1012,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nTensor Transposition Detected! Retrying GPT-J model loading..."); return res; } + + n_vocab = gptj_ctx_v1.hparams.n_vocab; + // determine the required inference memory per token: legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); @@ -876,6 +1044,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return loadresult; } + n_vocab = gptj_ctx_v3.hparams.n_vocab; + // determine the required inference memory per token: gptj_eval(gptj_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch); @@ -912,6 +1082,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return loadresult; } + n_vocab = gptj_ctx_v2.hparams.n_vocab; + // determine the required inference memory per token: gptj_v2_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); @@ -948,6 +1120,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return res; } + n_vocab = neox_ctx_v3.hparams.n_vocab; + // determine the required inference memory per token: gpt_neox_eval(neox_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch); @@ -970,6 +1144,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return res; } + n_vocab = neox_ctx_v2.hparams.n_vocab; + // determine the required inference memory per token: gpt_neox_v2_eval(neox_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); @@ -1005,6 +1181,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return ModelLoadResult::FAIL; } + n_vocab = mpt_ctx_v3.hparams.n_vocab; + // determine the required inference memory per token: mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token, use_scratch); return ModelLoadResult::SUCCESS; @@ -1084,6 +1262,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o generation_finished = false; // Set current generation status generated_tokens.clear(); // New Generation, new tokens + std::string grammarstr = inputs.grammar; + load_grammar(grammarstr); + if (params.repeat_last_n < 1) { params.repeat_last_n = 1; @@ -1193,59 +1374,9 @@ generation_outputs gpttype_generate(const 
generation_inputs inputs, generation_o timer_start(); double time1 = 0, time2 = 0; - int32_t n_vocab = 0; - if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2) + if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) { - n_vocab = llama_v2_n_vocab(llama_ctx_v2); - } - else if(file_format == FileFormat::GGJT_3) - { - n_vocab = llama_v3_n_vocab(llama_ctx_v3); - } - else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) - { - n_vocab = llama_n_vocab(llama_ctx_v4); - } - else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2) - { - n_vocab = gptj_ctx_v1.hparams.n_vocab; - } - else if(file_format == FileFormat::GPTJ_3 || file_format==FileFormat::GPTJ_4) - { - n_vocab = gptj_ctx_v2.hparams.n_vocab; - } - else if(file_format==FileFormat::GPTJ_5) - { - n_vocab = gptj_ctx_v3.hparams.n_vocab; - } - else if(file_format == FileFormat::GPT2_1) - { - n_vocab = gpt2_ctx_v1.hparams.n_vocab; - } - else if(file_format == FileFormat::GPT2_2 || file_format==FileFormat::GPT2_3) - { - n_vocab = gpt2_ctx_v2.hparams.n_vocab; - } - else if(file_format==FileFormat::GPT2_4) - { - n_vocab = gpt2_ctx_v3.hparams.n_vocab; - } - else if(file_format == FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5) - { - n_vocab = neox_ctx_v2.hparams.n_vocab; - } - else if( file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7) - { - n_vocab = neox_ctx_v3.hparams.n_vocab; - } - else if( file_format==FileFormat::MPT_1) - { - n_vocab = mpt_ctx_v3.hparams.n_vocab; - } - else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) - { - n_vocab = vocab.id_to_token.size(); //handled seperately if(n_past==0) { if(file_format == FileFormat::RWKV_1) @@ -1276,9 +1407,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o } } } - else + + if(n_vocab<=0) { - printf("Bad format!"); + printf("\nWarning! 
n_vocab is invalid, maybe bad format!"); } //prepare banned tokens @@ -1459,107 +1591,52 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o } } - unsigned int eosID = 0; + unsigned int eosID = GetEosID(file_format, n_vocab); float * logitsPtr; + float lowestLogit = 0; int btsize = banned_token_ids.size(); if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) { if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON) { logitsPtr = llama_get_logits(llama_ctx_v4); - eosID = llama_token_eos(llama_ctx_v4); } else if(file_format == FileFormat::GGJT_3) { logitsPtr = llama_v3_get_logits(llama_ctx_v3); - eosID = llama_v3_token_eos(); } else { logitsPtr = llama_v2_get_logits(llama_ctx_v2); - eosID = llama_v3_token_eos(); - } - - float lowestLogit = LowestLogit(logitsPtr,n_vocab); - if (!unbanTokens && !inputs.unban_tokens_rt) - { - // set the logit of the eos token (2) to -INF to avoid sampling it - logitsPtr[eosID] = lowestLogit; - } - - if(btsize>0) - { - for(int t=0;t0) + { + for(int t=0;t eosID) - { - logits[eosID] = lowestLogit; - } - else - { - //special case, starcoder models use ID 0 for EOS - if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4) - { - eosID = 0; - logits[eosID] = lowestLogit; - - } - } - } - - // set the logit of the eos token (0) to minimum to avoid sampling it - if (file_format == FileFormat::RWKV_1 || - file_format == FileFormat::RWKV_2 || - file_format == FileFormat::NEOX_1 || - file_format == FileFormat::NEOX_2 || - file_format == FileFormat::NEOX_3 || - file_format == FileFormat::NEOX_4 || - file_format == FileFormat::NEOX_5 || - file_format == FileFormat::NEOX_6 || - file_format == FileFormat::NEOX_7 || - file_format == FileFormat::MPT_1) - { - eosID = 0; - logits[eosID] = lowestLogit; - } - } - - if(btsize>0) - { - for (int t = 0; t < btsize; ++t) - { - logits[banned_token_ids[t]] = lowestLogit; - } + logitsPtr[banned_token_ids[t]]=lowestLogit; } } id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_a, top_p, typical_p, tfs_z, temp, rng, - params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order); + params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar); + + if (grammar != nullptr) { + grammar_accept_token(file_format, n_vocab, grammar, id); + } last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); diff --git a/koboldcpp.py b/koboldcpp.py index 099985240..53b0cf5a4 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -63,7 +63,8 @@ class generation_inputs(ctypes.Structure): ("sampler_len", ctypes.c_int), ("unban_tokens_rt", ctypes.c_bool), ("stop_sequence", ctypes.c_char_p * stop_token_max), - ("stream_sse", ctypes.c_bool)] + ("stream_sse", ctypes.c_bool), + ("grammar", ctypes.c_char_p)] class generation_outputs(ctypes.Structure): _fields_ = [("status", ctypes.c_int), @@ -277,7 +278,7 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False): +def 
generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar=''): global maxctx, args inputs = generation_inputs() outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) @@ -299,6 +300,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_ inputs.rep_pen = rep_pen inputs.rep_pen_range = rep_pen_range inputs.stream_sse = stream_sse + inputs.grammar = grammar.encode("UTF-8") inputs.unban_tokens_rt = not use_default_badwordsids if args.usemirostat and args.usemirostat[0]>0: inputs.mirostat = int(args.usemirostat[0]) @@ -399,7 +401,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): seed=genparams.get('sampler_seed', -1), stop_sequence=genparams.get('stop_sequence', []), use_default_badwordsids=genparams.get('use_default_badwordsids', True), - stream_sse=stream_flag) + stream_sse=stream_flag, + grammar=genparams.get('grammar', '')) else: return generate(prompt=newprompt, @@ -420,7 +423,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): seed=genparams.get('sampler_seed', -1), stop_sequence=genparams.get('stop_sequence', []), use_default_badwordsids=genparams.get('use_default_badwordsids', True), - stream_sse=stream_flag) + stream_sse=stream_flag, + grammar=genparams.get('grammar', '')) recvtxt = "" if stream_flag:
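
Usage sketch for the new `grammar` field: generate() in koboldcpp.py now accepts a GBNF grammar string, which gpttype_generate() passes to load_grammar(); the string is parsed by grammar_parser::parse() and must define a `root` rule, otherwise it is reported as invalid and ignored. The snippet below is an illustrative assumption of how a caller might use it, not part of the patch itself; only the grammar= keyword argument and its empty-string default come from the code above, and the grammar text and prompt are hypothetical.

    # Constrain sampling to a fixed set of answers via the new grammar field.
    # The GBNF rule text is a hypothetical example; an empty string (the default)
    # leaves grammar sampling disabled, since load_grammar() only builds a
    # llama_grammar when the string is non-empty.
    yes_no_grammar = 'root ::= "yes" | "no"'

    result = generate(
        prompt="Is the sky blue? Answer with one word: ",
        max_length=4,
        temperature=0.7,
        grammar=yes_no_grammar,
    )

HTTP clients get the same behaviour by including a "grammar" string in the request payload, which the server handler forwards via genparams.get('grammar', '').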