added grammar sampling

Concedo 2023-09-18 23:02:00 +08:00
parent 951614bfc6
commit 8c453d1e4e
6 changed files with 291 additions and 205 deletions
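
In short, this commit wires llama.cpp's GBNF grammar-constrained sampling into koboldcpp: grammar-parser is added to the CMake and Makefile builds, a grammar string is threaded from the Python bindings through generation_inputs, and the gpttype adapter gains helpers to parse the grammar, mask tokens the grammar forbids during sampling, and advance the grammar state after each accepted token. As a rough illustration (not part of the diff), a client-supplied grammar in the GBNF notation that grammar-parser.cpp accepts could be as small as:

root   ::= answer
answer ::= "yes" | "no"

With such a grammar loaded, sampling can only produce text that matches root.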

CMakeLists.txt

@@ -380,7 +380,9 @@ set_target_properties(ggml_v2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(common2
common/common.cpp
common/common.h)
common/common.h
common/grammar-parser.h
common/grammar-parser.cpp)
target_include_directories(common2 PUBLIC . ./otherarch ./otherarch/tools ./examples ./common)
target_compile_features(common2 PUBLIC cxx_std_11) # don't bump
target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS})

Makefile

@@ -426,19 +426,19 @@ gguf: examples/gguf/gguf.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
#generated libraries
koboldcpp_default: ggml.o ggml_v2.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o $(OBJS)
koboldcpp_default: ggml.o ggml_v2.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o grammar-parser.o $(OBJS)
$(DEFAULT_BUILD)
koboldcpp_openblas: ggml_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o $(OBJS)
koboldcpp_openblas: ggml_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o common.o gpttype_adapter.o k_quants.o ggml-alloc.o grammar-parser.o $(OBJS)
$(OPENBLAS_BUILD)
koboldcpp_failsafe: ggml_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_failsafe.o ggml-alloc.o $(OBJS)
koboldcpp_failsafe: ggml_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_failsafe.o ggml-alloc.o grammar-parser.o $(OBJS)
$(FAILSAFE_BUILD)
koboldcpp_noavx2: ggml_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_noavx2.o ggml-alloc.o $(OBJS)
koboldcpp_noavx2: ggml_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_noavx2.o ggml-alloc.o grammar-parser.o $(OBJS)
$(NOAVX2_BUILD)
koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o ggml-alloc.o $(OBJS)
koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o ggml-alloc.o grammar-parser.o $(OBJS)
$(CLBLAST_BUILD)
koboldcpp_cublas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o $(CUBLAS_OBJS) $(OBJS)
koboldcpp_cublas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o grammar-parser.o $(CUBLAS_OBJS) $(OBJS)
$(CUBLAS_BUILD)
koboldcpp_hipblas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o $(HIP_OBJS) $(OBJS)
koboldcpp_hipblas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o grammar-parser.o $(HIP_OBJS) $(OBJS)
$(HIPBLAS_BUILD)
quantize_llama: examples/quantize/quantize.cpp ggml.o llama.o k_quants.o ggml-alloc.o

build-info.h

@@ -3,5 +3,7 @@
#define BUILD_NUMBER 999
#define BUILD_COMMIT "KOBOLDCPP"
#define BUILD_COMPILER "KCPP"
#define BUILD_TARGET "KCPP"
#endif // BUILD_INFO_H

expose.h

@@ -72,6 +72,7 @@ struct generation_inputs
const bool unban_tokens_rt;
const char * stop_sequence[stop_token_max];
const bool stream_sse;
const char * grammar;
};
struct generation_outputs
{

gpttype_adapter.cpp

@@ -11,6 +11,7 @@
#include <mutex>
#include "model_adapter.h"
#include "otherarch.h"
#include "grammar-parser.h"
//for easier compilation
//concat source files into one file for compilation purposes
@@ -41,10 +42,14 @@ int last_token_count = 0;
stop_reason last_stop_reason = stop_reason::INVALID;
std::vector<std::string> generated_tokens;
llama_grammar * grammar = nullptr; //currently used grammar
grammar_parser::parse_state parsed_grammar;
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
static FileFormat file_format = FileFormat::BADFORMAT;
static gpt_vocab vocab;
static int32_t n_vocab = 0;
static gptj_v1_model gptj_ctx_v1;
static gptj_v2_model gptj_ctx_v2;
@@ -61,6 +66,7 @@ static mpt_model mpt_ctx_v3;
static rwkv_v2_context * rwkv_ctx_v2;
static rwkv_context * rwkv_ctx_v3;
static llama_v2_context * llama_ctx_v2;
static llama_v3_context * llama_ctx_v3;
static llama_context * llama_ctx_v4;
@@ -115,6 +121,133 @@ inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr
}
static std::string FileFormatTokenizeID(int id, FileFormat file_format)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
return std::string(llama_v2_token_to_str(llama_ctx_v2, id));
}
else if (file_format == FileFormat::GGJT_3)
{
return std::string(llama_v3_token_to_str(llama_ctx_v3, id));
}
else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
return std::string(llama_token_to_str(llama_ctx_v4, id));
}
else
{
return vocab.id_to_token[id];
}
}
static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGML)
{
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGJT_3)
{
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
}
else
{
output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true);
}
}
else
{
// tokenize the prompt
output_tokens = ::gpt_tokenize(vocab, str_to_tokenize);
}
}
static int GetEosID(FileFormat file_format, int32_t n_vocab)
{
unsigned int eosID = 0;
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
eosID = llama_token_eos(llama_ctx_v4);
}
else if(file_format == FileFormat::GGJT_3)
{
eosID = llama_v3_token_eos();
}
else
{
eosID = llama_v3_token_eos();
}
}
else
{
if (file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPT2_2 ||
file_format == FileFormat::GPT2_3 ||
file_format == FileFormat::GPT2_4 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::GPTJ_3 ||
file_format == FileFormat::GPTJ_4 ||
file_format == FileFormat::GPTJ_5)
{
eosID = 50256;
if (n_vocab <= eosID)
{
//special case, starcoder models use ID 0 for EOS
eosID = 0;
}
}
if (file_format == FileFormat::RWKV_1 ||
file_format == FileFormat::RWKV_2 ||
file_format == FileFormat::NEOX_1 ||
file_format == FileFormat::NEOX_2 ||
file_format == FileFormat::NEOX_3 ||
file_format == FileFormat::NEOX_4 ||
file_format == FileFormat::NEOX_5 ||
file_format == FileFormat::NEOX_6 ||
file_format == FileFormat::NEOX_7 ||
file_format == FileFormat::MPT_1)
{
eosID = 0;
}
}
return eosID;
}
static float LowestLogit(const std::vector<float> & logits)
{
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static float LowestLogit(const float *logits, size_t size)
{
if (size == 0) {
// Handle the case of an empty array
return 0.0;
}
int topid = std::min_element(logits, logits + size) - logits;
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static std::string RemoveBell(const std::string & input) //removes the bell character
{
std::string word2;
std::remove_copy(input.begin(), input.end(), std::back_inserter(word2), '\a');
return word2;
}
llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
{
llama_sample_softmax(nullptr, candidates);
@@ -256,8 +389,47 @@ void sample_temperature(llama_token_data_array * candidates_p, float temp)
}
}
void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
const int64_t t_start_sample_us = ggml_time_us();
bool allow_eos = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
allow_eos = true;
break;
}
}
const llama_token eos = GetEosID(file_format,n_vocab);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = FileFormatTokenizeID(id,file_format);
if (id == eos) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
for (const auto & reject : rejects) {
candidates->data[reject.index].logit = -INFINITY;
}
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order)
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
{
int id = 0;
std::vector<llama_token_data> candidates;
@@ -268,6 +440,10 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (grammar != nullptr) {
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}
if (mirostat == 1 || mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
@@ -321,76 +497,48 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
return id;
}
static std::string FileFormatTokenizeID(int id, FileFormat file_format)
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
return std::string(llama_v2_token_to_str(llama_ctx_v2, id));
if (token == GetEosID(file_format,n_vocab)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
}
}
GGML_ASSERT(false);
}
else if (file_format == FileFormat::GGJT_3)
{
return std::string(llama_v3_token_to_str(llama_ctx_v3, id));
}
else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
return std::string(llama_token_to_str(llama_ctx_v4, id));
}
else
{
return vocab.id_to_token[id];
const std::string piece = FileFormatTokenizeID(token,file_format); //llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
}
static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format)
static void load_grammar(const std::string & gammarstr)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
if(grammar!=nullptr) //on demand free when next grammar is loaded
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGML)
{
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, true);
}
else if (file_format == FileFormat::GGJT_3)
{
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, true);
}
else
{
output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true);
}
llama_grammar_free(grammar);
grammar = nullptr;
}
else
{
// tokenize the prompt
output_tokens = ::gpt_tokenize(vocab, str_to_tokenize);
}
}
static float LowestLogit(const std::vector<float> & logits)
{
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static float LowestLogit(const float *logits, size_t size)
{
if (size == 0) {
// Handle the case of an empty array
return 0.0;
if (!gammarstr.empty()) {
parsed_grammar = grammar_parser::parse(gammarstr.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
printf("\nIgnored invalid grammar sampler.");
return;
}
grammar_parser::print_grammar(stderr, parsed_grammar);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
int topid = std::min_element(logits, logits + size) - logits;
float v = logits[topid];
return (v < 0 ? (v-8) : 0);
}
static std::string RemoveBell(const std::string & input) //removes the bell character
{
std::string word2;
std::remove_copy(input.begin(), input.end(), std::back_inserter(word2), '\a');
return word2;
}
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta)
@@ -522,6 +670,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
n_vocab = llama_v2_n_vocab(llama_ctx_v2);
//determine mem per token
const std::vector<int> tmp = {1, 2, 3, 4};
llama_v2_eval(llama_ctx_v2, tmp.data(), tmp.size(), 0, params.n_threads);
@@ -587,6 +737,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
n_vocab = llama_v3_n_vocab(llama_ctx_v3);
//determine mem per token
const std::vector<int> tmp = {1, 2, 3, 4};
auto er = llama_v3_eval(llama_ctx_v3, tmp.data(), tmp.size(), 0, params.n_threads);
@@ -663,6 +815,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
n_vocab = llama_n_vocab(llama_ctx_v4);
//determine mem per token
const std::vector<int> tmp = {1, 2, 3, 4};
auto er = llama_eval(llama_ctx_v4, tmp.data(), tmp.size(), 0, params.n_threads);
@@ -720,6 +874,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nRWKV Vocab: %u\n", vocabsiz);
logits.resize(vocabsiz);
n_vocab = vocab.id_to_token.size(); //handled separately
if (file_format == FileFormat::RWKV_1)
{
n_batch = 1;
@@ -790,6 +946,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
return res;
}
n_vocab = gpt2_ctx_v1.hparams.n_vocab;
// determine the required inference memory per token:
legacy_gpt2_eval(gpt2_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
return ModelLoadResult::SUCCESS;
@@ -809,6 +968,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
return res;
}
n_vocab = gpt2_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
return ModelLoadResult::SUCCESS;
@@ -829,6 +991,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
return res;
}
n_vocab = gpt2_ctx_v2.hparams.n_vocab;
// determine the required inference memory per token:
gpt2_v2_eval(gpt2_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
return ModelLoadResult::SUCCESS;
@@ -847,6 +1012,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\nTensor Transposition Detected! Retrying GPT-J model loading...");
return res;
}
n_vocab = gptj_ctx_v1.hparams.n_vocab;
// determine the required inference memory per token:
legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
@@ -876,6 +1044,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return loadresult;
}
n_vocab = gptj_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
gptj_eval(gptj_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
@@ -912,6 +1082,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return loadresult;
}
n_vocab = gptj_ctx_v2.hparams.n_vocab;
// determine the required inference memory per token:
gptj_v2_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -948,6 +1120,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return res;
}
n_vocab = neox_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
gpt_neox_eval(neox_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
@@ -970,6 +1144,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return res;
}
n_vocab = neox_ctx_v2.hparams.n_vocab;
// determine the required inference memory per token:
gpt_neox_v2_eval(neox_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -1005,6 +1181,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return ModelLoadResult::FAIL;
}
n_vocab = mpt_ctx_v3.hparams.n_vocab;
// determine the required inference memory per token:
mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token, use_scratch);
return ModelLoadResult::SUCCESS;
@@ -1084,6 +1262,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
generation_finished = false; // Set current generation status
generated_tokens.clear(); // New Generation, new tokens
std::string grammarstr = inputs.grammar;
load_grammar(grammarstr);
if (params.repeat_last_n < 1)
{
params.repeat_last_n = 1;
@@ -1193,59 +1374,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
timer_start();
double time1 = 0, time2 = 0;
int32_t n_vocab = 0;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
n_vocab = llama_v2_n_vocab(llama_ctx_v2);
}
else if(file_format == FileFormat::GGJT_3)
{
n_vocab = llama_v3_n_vocab(llama_ctx_v3);
}
else if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
n_vocab = llama_n_vocab(llama_ctx_v4);
}
else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2)
{
n_vocab = gptj_ctx_v1.hparams.n_vocab;
}
else if(file_format == FileFormat::GPTJ_3 || file_format==FileFormat::GPTJ_4)
{
n_vocab = gptj_ctx_v2.hparams.n_vocab;
}
else if(file_format==FileFormat::GPTJ_5)
{
n_vocab = gptj_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::GPT2_1)
{
n_vocab = gpt2_ctx_v1.hparams.n_vocab;
}
else if(file_format == FileFormat::GPT2_2 || file_format==FileFormat::GPT2_3)
{
n_vocab = gpt2_ctx_v2.hparams.n_vocab;
}
else if(file_format==FileFormat::GPT2_4)
{
n_vocab = gpt2_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
{
n_vocab = neox_ctx_v2.hparams.n_vocab;
}
else if( file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7)
{
n_vocab = neox_ctx_v3.hparams.n_vocab;
}
else if( file_format==FileFormat::MPT_1)
{
n_vocab = mpt_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
n_vocab = vocab.id_to_token.size(); //handled seperately
if(n_past==0)
{
if(file_format == FileFormat::RWKV_1)
@@ -1276,9 +1407,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
}
else
if(n_vocab<=0)
{
printf("Bad format!");
printf("\nWarning! n_vocab is invalid, maybe bad format!");
}
//prepare banned tokens
@@ -1459,107 +1591,52 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
unsigned int eosID = 0;
unsigned int eosID = GetEosID(file_format, n_vocab);
float * logitsPtr;
float lowestLogit = 0;
int btsize = banned_token_ids.size();
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
{
logitsPtr = llama_get_logits(llama_ctx_v4);
eosID = llama_token_eos(llama_ctx_v4);
}
else if(file_format == FileFormat::GGJT_3)
{
logitsPtr = llama_v3_get_logits(llama_ctx_v3);
eosID = llama_v3_token_eos();
}
else
{
logitsPtr = llama_v2_get_logits(llama_ctx_v2);
eosID = llama_v3_token_eos();
}
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
if (!unbanTokens && !inputs.unban_tokens_rt)
{
// set the logit of the eos token (2) to -INF to avoid sampling it
logitsPtr[eosID] = lowestLogit;
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
lowestLogit = LowestLogit(logitsPtr,n_vocab);
}
else
{
logitsPtr = logits.data();
float lowestLogit = LowestLogit(logits);
if (!unbanTokens && !inputs.unban_tokens_rt)
lowestLogit = LowestLogit(logits);
}
if (!unbanTokens && !inputs.unban_tokens_rt)
{
// set the logit of the eos token to very low to avoid sampling it
logitsPtr[eosID] = lowestLogit;
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
//gpt2 uses negative logits, so we cant zero it
// set the logit of the eos token to minimum to avoid sampling it
if (file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPT2_2 ||
file_format == FileFormat::GPT2_3 ||
file_format == FileFormat::GPT2_4 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::GPTJ_3 ||
file_format == FileFormat::GPTJ_4 ||
file_format == FileFormat::GPTJ_5)
{
eosID = 50256;
if(logits.size() > eosID)
{
logits[eosID] = lowestLogit;
}
else
{
//special case, starcoder models use ID 0 for EOS
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
{
eosID = 0;
logits[eosID] = lowestLogit;
}
}
}
// set the logit of the eos token (0) to minimum to avoid sampling it
if (file_format == FileFormat::RWKV_1 ||
file_format == FileFormat::RWKV_2 ||
file_format == FileFormat::NEOX_1 ||
file_format == FileFormat::NEOX_2 ||
file_format == FileFormat::NEOX_3 ||
file_format == FileFormat::NEOX_4 ||
file_format == FileFormat::NEOX_5 ||
file_format == FileFormat::NEOX_6 ||
file_format == FileFormat::NEOX_7 ||
file_format == FileFormat::MPT_1)
{
eosID = 0;
logits[eosID] = lowestLogit;
}
}
if(btsize>0)
{
for (int t = 0; t < btsize; ++t)
{
logits[banned_token_ids[t]] = lowestLogit;
}
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order);
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar);
if (grammar != nullptr) {
grammar_accept_token(file_format, n_vocab, grammar, id);
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
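
Taken together, the adapter changes give each request this per-token flow: load_grammar parses the incoming GBNF string (freeing any grammar left over from the previous request), sample_grammar sets the logit of every candidate the grammar rejects to -INFINITY (and only allows EOS once a parse stack is empty), SampleLogits then runs the usual samplers over the surviving candidates, and grammar_accept_token advances the parse stacks with the chosen token. A tiny self-contained Python sketch of just the masking step, with a toy vocabulary standing in for the real llama_grammar machinery:

import math

def mask_and_sample(logits, allowed_ids):
    # like sample_grammar: candidates the grammar rejects get -inf so they can never win
    masked = [l if i in allowed_ids else -math.inf for i, l in enumerate(logits)]
    # a greedy pick stands in for the real top-k/top-p/temperature/mirostat samplers
    return max(range(len(masked)), key=lambda i: masked[i])

vocab   = ["maybe", "yes", "no", "<eos>"]
logits  = [2.5, 1.2, 0.8, 0.1]   # "maybe" has the highest raw logit
allowed = {1, 2}                 # toy grammar state: only "yes" or "no" may come next
print(vocab[mask_and_sample(logits, allowed)])   # prints "yes", not "maybe"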

koboldcpp.py

@@ -63,7 +63,8 @@ class generation_inputs(ctypes.Structure):
("sampler_len", ctypes.c_int),
("unban_tokens_rt", ctypes.c_bool),
("stop_sequence", ctypes.c_char_p * stop_token_max),
("stream_sse", ctypes.c_bool)]
("stream_sse", ctypes.c_bool),
("grammar", ctypes.c_char_p)]
class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
@@ -277,7 +278,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False):
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=True, stream_sse=False, grammar=''):
global maxctx, args
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@@ -299,6 +300,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_
inputs.rep_pen = rep_pen
inputs.rep_pen_range = rep_pen_range
inputs.stream_sse = stream_sse
inputs.grammar = grammar.encode("UTF-8")
inputs.unban_tokens_rt = not use_default_badwordsids
if args.usemirostat and args.usemirostat[0]>0:
inputs.mirostat = int(args.usemirostat[0])
@@ -399,7 +401,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
seed=genparams.get('sampler_seed', -1),
stop_sequence=genparams.get('stop_sequence', []),
use_default_badwordsids=genparams.get('use_default_badwordsids', True),
stream_sse=stream_flag)
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''))
else:
return generate(prompt=newprompt,
@@ -420,7 +423,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
seed=genparams.get('sampler_seed', -1),
stop_sequence=genparams.get('stop_sequence', []),
use_default_badwordsids=genparams.get('use_default_badwordsids', True),
stream_sse=stream_flag)
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''))
recvtxt = ""
if stream_flag:
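
On the client side the grammar is just one more key in the generation payload: the request handler above forwards genparams.get('grammar', '') into generate(), which UTF-8 encodes it into the new generation_inputs field. A usage sketch (the endpoint path and port are assumed koboldcpp defaults, not shown in this diff):

import json, urllib.request

payload = {
    "prompt": "Is the sky blue? Answer: ",
    "max_length": 8,
    "grammar": 'root ::= "yes" | "no"',   # GBNF text passed through to load_grammar()
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate",   # assumed default address of a local koboldcpp server
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
print(urllib.request.urlopen(req).read().decode("utf-8"))

The same string can also be passed directly to the in-process binding, e.g. generate(prompt, grammar='root ::= "yes" | "no"').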