diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 54b3b5521..3449be635 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -65,6 +65,7 @@ static size_t mem_per_token = 0;
 static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
+static std::vector<llama_token_data> top_picks;
 
 inline bool IsNanCheck(float f)
 {
@@ -96,11 +97,26 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng
     llama_sample_softmax(nullptr, candidates);
     std::vector<float> probs;
     probs.reserve(candidates->size);
+    top_picks.clear();
     for (size_t i = 0; i < candidates->size; ++i)
     {
-        probs.push_back(candidates->data[i].p);
+        probs.push_back(candidates->data[i].p);
     }
+
     std::discrete_distribution<> dist(probs.begin(), probs.end());
     int idx = dist(rng);
+
+    if(debugmode)
+    {
+        top_picks.push_back(candidates->data[idx]);
+        for (size_t i = 0; (i < candidates->size && i<4); ++i)
+        {
+            if(i!=idx)
+            {
+                top_picks.push_back(candidates->data[i]);
+            }
+        }
+    }
+
     llama_token result = candidates->data[idx].id;
     return result;
 }
@@ -216,6 +232,22 @@ int mirostat, float mirostat_tau, float mirostat_eta)
     return id;
 }
 
+static std::string FileFormatTokenizeID(int id, FileFormat file_format)
+{
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
+    {
+        return std::string(llama_v2_token_to_str(llama_ctx_v2, id));
+    }
+    else if (file_format == FileFormat::GGJT_3)
+    {
+        return std::string(llama_token_to_str(llama_ctx_v3, id));
+    }
+    else
+    {
+        return vocab.id_to_token[id];
+    }
+}
+
 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
     ggml_time_init();
@@ -593,7 +625,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 }
 
 
-
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
 {
     stop_sequence.clear();
@@ -628,7 +659,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     }
     if (params.top_k < 1)
     {
-        params.top_k = 300; //to disable top_k we actually need to increase this value to a very high number
+        params.top_k = 200; //to disable top_k we actually need to increase this value to a very high number
     }
     if (params.seed <= 0)
     {
@@ -795,49 +826,32 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
     printf("\n");
 
-    if(debugmode)
+    if (debugmode)
     {
-        printf("\n[Debug: Dump Input Tokens, format: %d]\n",file_format);
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
-        {
-            for (auto id : embd_inp)
-            {
-                printf("'%s (%d)', ",llama_v2_token_to_str(llama_ctx_v2, id),id);
-            }
+        std::string outstr = "";
+        printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
 
-            printf("\n\n[Debug: Context Size = %d]\n",current_context_tokens.size());
-            for (auto id : current_context_tokens)
-            {
-                printf("'%s (%d)', ",llama_v2_token_to_str(llama_ctx_v2, id),id);
-            }
-        }
-        else if (file_format == FileFormat::GGJT_3)
+        std::string tmp = "";
+        for (auto id : embd_inp)
         {
-            for (auto id : embd_inp)
-            {
-                printf("'%s (%d)', ",llama_token_to_str(llama_ctx_v3, id),id);
-            }
-            printf("\n\n[Debug: Context Size = %d]\n",current_context_tokens.size());
-            for (auto id : current_context_tokens)
-            {
-                printf("'%s (%d)', ",llama_token_to_str(llama_ctx_v3, id),id);
-            }
+            tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
         }
-        else
+        ::utreplace(tmp, "\n", "\\n");
+        outstr += tmp;
+
+        outstr += "\n\n[Debug: Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
+        tmp = "";
+        for (auto id : current_context_tokens)
         {
-            for (auto id : embd_inp)
-            {
-                printf("'%s (%d)', ",vocab.id_to_token[id].c_str(),id);
-            }
-            printf("\n\n[Debug: Context Size = %d]\n",current_context_tokens.size());
-            for (auto id : current_context_tokens)
-            {
-                printf("'%s (%d)', ",vocab.id_to_token[id].c_str(),id);
-            }
+            tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
         }
+        ::utreplace(tmp, "\n", "\\n");
+        outstr += tmp;
+        printf("%s", outstr.c_str());
+        printf("\n\n");
     }
-
+
     while (remaining_tokens > 0)
     {
         gpt_vocab::id id = 0;
@@ -851,7 +865,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             printf("\rProcessing Prompt%s (%d / %d tokens)", (blasmode ? " [BLAS]" : ""), input_consumed, embd_inp.size());
         }
         else
-        {
+        {
             printf("\rGenerating (%d / %d tokens)", (1 + params.n_predict - remaining_tokens), params.n_predict);
         }
         fflush(stdout);
@@ -956,11 +970,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 {
                     // set the logit of the eos token (2) to zero to avoid sampling it
                     logits[llama_token_eos()] = 0;
-                    //set logits of opening square bracket to zero.
-                    logits[518] = 0;
-                    logits[29961] = 0;
+                    //set logits of opening square bracket to zero. (disabled as obsolete)
+                    // logits[518] = 0;
+                    // logits[29961] = 0;
                 }
 
+
                 id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty,
                 top_k, top_p, typical_p, tfs_z, temp, rng,
                 params.mirostat,params.mirostat_tau,params.mirostat_eta);
@@ -970,7 +985,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         {
             if (!unbanTokens)
            {
-                // set the logit of the eos token (2) to zero to avoid sampling it
+                //gpt2 uses negative logits, so we cant zero it
+                // set the logit of the eos token to minimum to avoid sampling it
                 if ((file_format == FileFormat::GPT2_1 ||
                      file_format == FileFormat::GPT2_2 ||
                      file_format == FileFormat::GPT2_3 ||
@@ -981,11 +997,24 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                      file_format == FileFormat::GPTJ_4 ||
                      file_format == FileFormat::GPTJ_5) &&
                     logits.size() > 50256)
-                {
-                    logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
+                {
+                    int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
+                    logits[50256] = (logits[topid] < 0 ? logits[topid] : 0);
                 }
-                //gpt2 uses negative logits, so we cant zero it
-            }
+
+                // set the logit of the eos token (0) to minimum to avoid sampling it
+                if (file_format == FileFormat::NEOX_1 ||
+                    file_format == FileFormat::NEOX_2 ||
+                    file_format == FileFormat::NEOX_3 ||
+                    file_format == FileFormat::NEOX_4 ||
+                    file_format == FileFormat::NEOX_5 ||
+                    file_format == FileFormat::NEOX_6 ||
+                    file_format == FileFormat::NEOX_7)
+                {
+                    int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
+                    logits[0] = (logits[topid] < 0 ? logits[topid] : 0);
+                }
+            }
 
             id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty,
                               top_k, top_p, typical_p, tfs_z, temp, rng,
@@ -1026,6 +1055,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                     concat_output += vocab.id_to_token[id].c_str();
                 }
             }
+
+            if(debugmode && top_picks.size()>0)
+            {
+                printf(" [");
+                bool firstloop = true;
+                for (auto & pick : top_picks)
+                {
+                    if (!firstloop)
+                    {
+                        printf(" ");
+                    }
+                    firstloop = false;
+                    std::string tokenizedstr = FileFormatTokenizeID(pick.id, file_format);
+                    ::utreplace(tokenizedstr, "\n", "\\n");
+                    printf("(%s %.2f%%)", tokenizedstr.c_str(), pick.p*100);
+                }
+                printf("]\n");
+            }
+
             for (const auto &matched : stop_sequence)
             {
                 if (concat_output.find(matched) != std::string::npos)
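
A note on the sampler change in the first hunk: after llama_sample_softmax has normalized and sorted the candidates, the token is drawn with std::discrete_distribution, which weights each index by its probability, and in debug mode the chosen candidate plus up to three of the top alternatives are stashed in top_picks for the per-token printout added near the end of the patch. The following is a minimal standalone sketch of the same technique; the Candidate struct and sample_and_record are illustrative stand-ins (the real code uses llama.cpp's llama_token_data and the patch's debugmode/top_picks globals), not part of the patch.

#include <cstdio>
#include <random>
#include <vector>

// Illustrative stand-in for llama.cpp's llama_token_data (token id + softmaxed probability).
struct Candidate { int id; float p; };

// Draw one candidate index weighted by probability, recording the pick
// followed by up to three runners-up, as the patched sample_token() does.
static int sample_and_record(const std::vector<Candidate> &cands,
                             std::mt19937 &rng,
                             std::vector<Candidate> &top_picks)
{
    std::vector<float> probs;
    probs.reserve(cands.size());
    for (const auto &c : cands) { probs.push_back(c.p); }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    top_picks.clear();
    top_picks.push_back(cands[idx]); // sampled token first, so it prints first
    for (size_t i = 0; i < cands.size() && i < 4; ++i)
    {
        if ((int)i != idx) { top_picks.push_back(cands[i]); }
    }
    return idx;
}

int main()
{
    // Candidates as llama_sample_softmax leaves them: sorted by p, summing to 1.
    std::vector<Candidate> cands = {{10, 0.5f}, {42, 0.3f}, {7, 0.15f}, {3, 0.05f}};
    std::mt19937 rng(1337);
    std::vector<Candidate> top_picks;
    int idx = sample_and_record(cands, rng, top_picks);
    printf("sampled id=%d [", cands[idx].id);
    for (const auto &pick : top_picks) { printf("(%d %.2f%%) ", pick.id, pick.p * 100); }
    printf("]\n");
    return 0;
}

Pushing the sampled token first keeps the first entry of the debug printout equal to the emitted token, with the remaining entries showing what the sampler passed over.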
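
The EOS-ban rework in the last hunks encodes one subtle point: writing 0 into a banned token's logit only works when logits are non-negative (as on the llama path, which zeroes logits[llama_token_eos()]); GPT-2/GPT-J/NeoX logits are raw scores that can all be negative, in which case a zeroed entry would become the argmax and the "banned" token would dominate sampling. The patch therefore clamps the entry to the smallest current logit (note it names that index topid even though min_element finds the lowest value). A compilable sketch of the trick, assuming a plain vector of logits; ban_token is a hypothetical helper name, not from the patch.

#include <algorithm>
#include <cstdio>
#include <vector>

// Suppress token_id by clamping its logit to min(current minimum, 0),
// mirroring the patch's min_element trick (which it inlines per format).
static void ban_token(std::vector<float> &logits, int token_id)
{
    int lowid = (int)(std::min_element(logits.begin(), logits.end()) - logits.begin());
    logits[token_id] = (logits[lowid] < 0.0f ? logits[lowid] : 0.0f);
}

int main()
{
    // All logits negative: writing 0 here would have made token 0 the argmax.
    std::vector<float> logits = {-1.2f, -0.3f, -4.7f};
    ban_token(logits, 0);
    printf("%.1f %.1f %.1f\n", logits[0], logits[1], logits[2]); // -4.7 -0.3 -4.7
    return 0;
}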